Compare commits

..

37 Commits

Author SHA1 Message Date
riccardom
7706f578fe use API 2026-06-18 19:25:55 +02:00
riccardom
daf5026192 Adds restart for MDM 2026-06-18 19:23:41 +02:00
riccardom
ec18b07959 NOP 2026-06-18 17:27:18 +02:00
riccardom
9628f016da Bridges embed to use the supervisor + remove providing of establishedChan from caller
(We now do AsyncStart + waitEstablishedOrDone, so everything is managed inside the
supervisor)
2026-06-18 17:10:17 +02:00
riccardom
b39e9df194 Renaming 2026-06-18 16:51:06 +02:00
riccardom
0388e0f262 Merge branch 'main' into lock_removal
# Conflicts:
#	client/internal/connect.go
#	client/ios/NetBirdSDK/client.go
#	client/server/server.go
#	client/server/server_test.go
2026-06-18 16:10:40 +02:00
riccardom
86f896723d Wait on broadcasted ended signal for establishment / done 2026-06-18 15:20:11 +02:00
riccardom
29ee84999c conn established (success) or done (end/failure..) are signals of the supervisor 2026-06-18 14:40:03 +02:00
riccardom
0e8fd22f36 Highlight what signals that the sup is running (which means the Connection is running because of UP/auto start) 2026-06-18 14:40:03 +02:00
riccardom
ff98105212 Clarifies: service -> ServiceRunning -> up -> ConnectionRunning -> connestablished ->connEstablished -> end of run -> connDone 2026-06-18 14:40:03 +02:00
riccardom
6465997a69 Aligns tests 2026-06-18 14:40:03 +02:00
riccardom
3204270c4b Removes all unrequired checks, since the lifetime now guarantees the non nil presence of connectClient 2026-06-18 14:40:03 +02:00
riccardom
6d3bcef2c4 Rename to clarify 2026-06-18 14:40:03 +02:00
riccardom
5d7cb30e5b Removes other occurrencies of connectClient check 2026-06-18 14:40:03 +02:00
riccardom
aff5da2c8e Log something better when UP doesn't find service grpc socket 2026-06-18 14:40:03 +02:00
riccardom
9b179be324 Defines an API for knowing if the SERVICE is running (regardless of up and down state)
New() builds s.connecClient and is called when the gRPC service is started.
Up() is invoked only IF a gRPC service IS running which is possible only if the New() was
called.
2026-06-17 23:30:30 +02:00
riccardom
33e7b6a8f1 Align tests 2026-06-17 23:00:53 +02:00
riccardom
e0cff5e240 If New creates, Starts MUST find a connectClient 2026-06-17 23:00:53 +02:00
riccardom
0085aebf77 Removes external connectWithRetryRuns 2026-06-17 08:52:02 +02:00
riccardom
91d2d341b7 Guard is done inside not from external. Stop called unconditionally 2026-06-16 23:18:50 +02:00
riccardom
8d46580c13 Not the run duty 2026-06-16 23:09:59 +02:00
riccardom
b42fe6a10f And now let's just avoid it at all 2026-06-16 21:58:05 +02:00
riccardom
0f5d7fdc07 Removes deadlock 2026-06-16 21:57:07 +02:00
riccardom
13c78d98f5 Client (and the supervisor within) now lives forever.
So checkin that it's nil isn't anymore an indirect
way to know the cleanup has succeeded
2026-06-16 18:55:08 +02:00
riccardom
d1229ed84c Restores context MD passed to GetInfo to mgmt (is it valuable data?) 2026-06-16 18:55:08 +02:00
riccardom
9758145517 Discriminate auth fails from mgmt unreachable.
Needed to avoid this workaround

if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
     s.actCancel()
}

in server.go Status(...)
2026-06-16 18:55:08 +02:00
riccardom
200a5a6a70 Rename. We keep only start command (long lasting command)
For stop we do synchronously.
2026-06-16 18:55:08 +02:00
riccardom
1f7b1ea863 IsRunning V1 2026-06-16 18:55:08 +02:00
riccardom
4abb10c1aa Fixes DisableAutoConnect semantics
DisableAutoConnect semantics
============================

Scope: governs ONLY the service-Start auto-connect decision. Once the
connection goroutine has been spawned (by any path), the flag is never
consulted again — the retry loop keeps trying to connect until ctx is
cancelled by Down / Stop / Logout.

Cases
-----

1. Service Start + DisableAutoConnect = true
   - No connection goroutine spawned.
   - clientRunning stays false.
   - State set to StatusIdle.
   - Daemon stays passive until an explicit Up RPC.

2. Service Start + DisableAutoConnect = false
   - Spawn connectWithRetryRuns.
   - clientRunning = true.
   - Retry loop runs until ctx cancelled.

3. Up RPC (any value of DisableAutoConnect)
   - Flag ignored. The user / admin explicitly asked to connect — by
     definition not "auto".
   - Spawn connectWithRetryRuns.
   - clientRunning = true.

4. MDM-triggered restart (any value of DisableAutoConnect)
   - Flag ignored. An MDM policy change applies new config to an
     already-running engine; treated as an implicit Up.
   - Spawn connectWithRetryRuns.
   - clientRunning = true.

5. Down / Stop / Logout
   - Cancels ctx → connectWithRetryRuns exits → close(giveUpChan).
   - cleanupConnection clears clientRunning = false.
   - DisableAutoConnect not involved.
Prepare for next step

Collapse error log into s.connect and rename it to more explicit connectOnce

# Conflicts:
#	client/server/server.go
2026-06-16 18:55:08 +02:00
riccardom
a45cefe57a IsRunning V0 2026-06-16 18:55:08 +02:00
riccardom
a6d504633f Use Stop not other direct calls like actCancel() things.. for now adding then removing 2026-06-16 18:55:08 +02:00
riccardom
70f2097fff config becomes start/run arg so we are sure it gets updated on any (re)start
This is because we now have the supervisor not destroyed over time

Still intermediate step where server.go regenerates the client entirely.
2026-06-16 18:31:46 +02:00
riccardom
befa9a879c DO NOT recreate ConnectClient! 2026-06-16 18:31:46 +02:00
riccardom
4152c41796 Wire the stop-for-any-reason to the sup stop (and remove race on engine!) 2026-06-16 18:31:46 +02:00
riccardom
8b76b3d824 Keep the command DON'T DO COPIES!!! 2026-06-16 18:31:46 +02:00
riccardom
0503a18644 Wraps client lifetime into a supervisor
- define needed supervisor context/variables
- will use runCancel as knob to know if the client is running. No extra boolean flags
- runWaiter is used to signal to the async run caller
2026-06-16 18:31:46 +02:00
riccardom
ec6512d660 Docker build env for windows/android 2026-06-16 15:31:52 +02:00
118 changed files with 1925 additions and 5305 deletions

View File

@@ -20,7 +20,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
@@ -59,12 +59,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: true

View File

@@ -15,7 +15,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1

View File

@@ -16,18 +16,18 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
@@ -48,7 +48,7 @@ jobs:
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
@@ -28,7 +28,7 @@ jobs:
id: test
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
with:
usesh: true
copyback: false

View File

@@ -18,7 +18,7 @@ jobs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
@@ -30,7 +30,7 @@ jobs:
- 'management/**'
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -41,7 +41,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: cache
with:
path: |
@@ -119,12 +119,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -135,7 +135,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -162,7 +162,7 @@ jobs:
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
@@ -175,12 +175,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -192,7 +192,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: cache-restore
with:
path: |
@@ -246,12 +246,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -266,7 +266,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -290,7 +290,7 @@ jobs:
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
@@ -306,12 +306,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -325,7 +325,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -347,7 +347,7 @@ jobs:
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
@@ -363,12 +363,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -383,7 +383,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -407,7 +407,7 @@ jobs:
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
@@ -424,12 +424,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -440,7 +440,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -484,7 +484,7 @@ jobs:
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
@@ -529,12 +529,12 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -545,7 +545,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -579,11 +579,10 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
env:
GIT_BRANCH: ${{ github.ref_name }}
api_benchmark:
name: "Management / Benchmark (API)"
@@ -624,12 +623,12 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -640,7 +639,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -674,13 +673,12 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/server/http/...
env:
GIT_BRANCH: ${{ github.ref_name }}
api_integration_test:
name: "Management / Integration"
@@ -694,12 +692,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -710,7 +708,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -736,7 +734,7 @@ jobs:
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird

View File

@@ -18,12 +18,12 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
id: go
with:
go-version-file: "go.mod"
@@ -35,7 +35,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: codespell
@@ -40,7 +40,7 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check for duplicate constants
@@ -48,7 +48,7 @@ jobs:
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false

View File

@@ -22,7 +22,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false

View File

@@ -16,11 +16,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Setup Android SDK
@@ -28,13 +28,13 @@ jobs:
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -54,11 +54,11 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: install gomobile

View File

@@ -27,7 +27,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
@@ -64,7 +64,7 @@ jobs:
if: steps.check_diff.outputs.diff_exists == 'true'
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
with:
usesh: true
copyback: false
@@ -135,7 +135,7 @@ jobs:
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
steps:
- name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
@@ -166,12 +166,12 @@ jobs:
fi
- name: Set up Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -186,9 +186,9 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
- name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
@@ -221,7 +221,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }}
@@ -347,7 +347,7 @@ jobs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps:
- name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
@@ -374,12 +374,12 @@ jobs:
fi
- name: Set up Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -420,7 +420,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -464,17 +464,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -488,7 +488,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
@@ -522,7 +522,7 @@ jobs:
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
@@ -534,13 +534,13 @@ jobs:
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
with:
name: release-ui
path: release-ui

View File

@@ -68,17 +68,17 @@ jobs:
run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -207,7 +207,7 @@ jobs:
- name: Build management docker image
working-directory: management
run: |
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
docker build -t netbirdio/management:latest .
- name: Build signal binary
working-directory: signal
@@ -216,7 +216,7 @@ jobs:
- name: Build signal docker image
working-directory: signal
run: |
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
@@ -225,7 +225,7 @@ jobs:
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
docker build -t netbirdio/relay:latest .
- name: run docker compose up
working-directory: infrastructure_files/artifacts
@@ -256,7 +256,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false

View File

@@ -19,11 +19,11 @@ jobs:
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Install dependencies
@@ -44,11 +44,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Build Wasm client

View File

@@ -247,7 +247,7 @@ dockers_v2:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile
extra_files:
@@ -295,7 +295,7 @@ dockers_v2:
- netbirdio/relay
- ghcr.io/netbirdio/relay
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile
platforms:
@@ -317,7 +317,7 @@ dockers_v2:
- netbirdio/signal
- ghcr.io/netbirdio/signal
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile
platforms:
@@ -339,7 +339,7 @@ dockers_v2:
- netbirdio/management
- ghcr.io/netbirdio/management
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile
platforms:
@@ -361,7 +361,7 @@ dockers_v2:
- netbirdio/upload
- ghcr.io/netbirdio/upload
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile
platforms:
@@ -383,7 +383,7 @@ dockers_v2:
- netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile
platforms:
@@ -405,7 +405,7 @@ dockers_v2:
- netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile
platforms:
@@ -462,13 +462,9 @@ checksum:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh
- glob: ./infrastructure_files/getting-started-enterprise.sh
- glob: ./infrastructure_files/migrate-to-enterprise.sh
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh
- glob: ./infrastructure_files/getting-started-enterprise.sh
- glob: ./infrastructure_files/migrate-to-enterprise.sh

View File

@@ -37,11 +37,6 @@
</strong>
</p>
> ### 🤖 NetBird Agent Network (Beta)
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
> read the docs at **[netbird.ai](https://netbird.ai)**.
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.

View File

@@ -1,39 +0,0 @@
# NetBird Agent Network
Agent Network is NetBird's access control layer for AI agents and the people who run
them. It gives every agent a real identity, tied to your identity provider (IdP), and
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
policy, with no API keys to leak.
> **Beta.** Agent Network is open source and can be self-hosted on your own
> infrastructure.
## How it works
Agent Network is built on two existing NetBird capabilities:
- **Overlay network** — the encrypted WireGuard mesh between peers.
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
key server-side, forwards to the API or gateway, and records usage.
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
resources (databases, internal APIs, self-hosted models) are reached directly over
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
## Where the code lives
There is no separate "agent-network" service — it reuses the reverse-proxy and management
components:
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
and runs the per-request middleware pipeline.
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
— the management-side control plane: providers, policies, guardrails, limits, routing,
and usage/access logs.
## Documentation
Full documentation, architecture, and quickstart:
**https://docs.netbird.io/agent-network**

View File

@@ -151,9 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
connectClient := internal.NewConnectClient(ctx, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -186,9 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
connectClient := internal.NewConnectClient(ctx, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
}
// Stop the internal client and free the resources

View File

@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
ProfileName: string(activeProf.ID),
ProfileName: activeProf.Name,
Username: currUser.Username,
})
if err != nil {

View File

@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
Username: &username,
})
if err != nil {
return "", fmt.Errorf("switch profile failed: %w", err)
return "", fmt.Errorf("switch profile failed: %v", err)
}
return profilemanager.ID(resp.Id), nil

View File

@@ -138,23 +138,26 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
return err
}
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return err
return fmt.Errorf("add profile request: %w", err)
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
@@ -163,6 +166,7 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
@@ -326,19 +330,3 @@ func wrapAmbiguityError(err error, handle string) error {
}
return err
}
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
// and returns the new profile's ID. It is the single entry point for profile
// creation, shared by `netbird profile add` and the `netbird up --profile
// <name>` auto-create path.
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
ProfileName: profileName,
Username: username,
})
if err != nil {
return "", fmt.Errorf("add profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
@@ -261,17 +262,46 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
return prefix + upper
}
// DialClientGRPCServer returns client connection to the daemon server.
// DialClientGRPCServer returns client connection to the daemon server. It waits
// (up to the timeout) for the daemon to become reachable so an `up` issued right
// after `service start` tolerates the startup race. Instead of grpc's blocking
// dial — whose raw "transport failed" retry warnings are silenced by the logger
// config — we drive the wait ourselves and emit one clean line per failed attempt.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
return grpc.DialContext(
conn, err := grpc.DialContext(
ctx,
strings.TrimPrefix(addr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
if err != nil {
return nil, err
}
conn.Connect()
for {
state := conn.GetState()
if state == connectivity.Ready {
return conn, nil
}
// Log only once the connection has actually failed — not during the
// brief Idle/Connecting phase on a healthy daemon (avoids a spurious
// line + wait when the daemon is already up).
if state == connectivity.TransientFailure {
log.Infof("waiting for the netbird daemon to become available at %s...", addr)
}
// Wake on the next state change, but at least every second so a stuck
// TransientFailure re-logs at a steady cadence until the timeout.
waitCtx, waitCancel := context.WithTimeout(ctx, time.Second)
conn.WaitForStateChange(waitCtx, state)
waitCancel()
if ctx.Err() != nil {
_ = conn.Close()
return nil, fmt.Errorf("daemon not reachable at %s: %w", addr, ctx.Err())
}
}
}
// WithBackOff execute function in backoff cycle.

View File

@@ -11,6 +11,7 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
@@ -110,10 +111,11 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
// Resolve the active profile's display name via the daemon, which runs
// as root and can read the per-user profile files. The local profile
// manager only knows the active profile ID, not its display name.
profName := getActiveProfileName(ctx)
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
@@ -165,25 +167,6 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
return resp, nil
}
// getActiveProfileName asks the daemon for the active profile's display
// name. The daemon runs as root and can read the per-user profile files to
// resolve the ID to its human-readable name. Returns an empty string on any
// error so status output degrades gracefully.
func getActiveProfileName(ctx context.Context) string {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return ""
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
if err != nil {
return ""
}
return resp.GetProfileName()
}
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected":

View File

@@ -128,9 +128,15 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
profileSwitched = true
}
@@ -145,52 +151,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
}
// switchOrCreateProfile switches the active profile to the one identified by
// handle, creating it first when it does not exist yet. This restores the
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
// missing profile instead of failing.
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
st, ok := gstatus.FromError(err)
if !ok || st.Code() != codes.NotFound {
return err
}
// Don't fail immediately on a create error: a concurrent run may
// have created the profile between the NotFound above and this
// call, in which case the retried switch still succeeds. Only
// surface the create error if the switch also fails.
_, createErr := createProfile(ctx, handle, username)
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
if createErr != nil {
return fmt.Errorf("create profile: %w", createErr)
}
return err
}
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return err
}
return nil
}
// createProfile dials the daemon and creates a new profile with the given
// display name, returning its generated ID. Use addProfileOnDaemon directly
// when a daemon client is already available to reuse the connection.
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
}
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
// override the default profile filepath if provided
if configPath != "" {
@@ -241,10 +201,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
connectClient := internal.NewConnectClient(ctx, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
return connectClient.Run(config, nil, util.FindFirstLogPath(logFiles))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {

View File

@@ -264,34 +264,24 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
client := internal.NewConnectClient(ctx, c.config, c.recorder)
client := internal.NewConnectClient(ctx, c.recorder)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
// The supervisor owns the run; we wait until it is established, ends with a
// startup error (permanent backoff err), or startCtx expires.
// TODO: make after-startup backoff err available
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run, ""); err != nil {
clientErr <- err
}
}()
client.RunAsync(c.config, nil)
select {
case <-startCtx.Done():
// ConnectClient.Stop now cancels its own run context and waits for the
// run loop to tear the engine down, so this cancel() is no longer
// required to break the deadlock and could be removed. It is kept as a
// defensive belt-and-suspenders: cancelling the parent context first
// guarantees the run loop is unblocked even if Stop's contract regresses.
if err := client.WaitEstablishedOrDone(startCtx); err != nil {
// Either startCtx expired while connecting, or the run ended before it
// established. Cancel the client context before stopping: Engine.Start
// blocks on the signal stream while holding the engine mutex and only
// unblocks on cancellation. Stopping first would deadlock on that mutex.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
return fmt.Errorf("stop error after startup failure. Stop error: %w. Startup: %w", stopErr, err)
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
}
c.connect = client

View File

@@ -11,7 +11,6 @@ import (
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
@@ -19,6 +18,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -49,17 +49,23 @@ import (
"github.com/netbirdio/netbird/version"
)
// androidRunOverride is set on Android to inject mobile dependencies
// when using embed.Client (which calls Run() with empty MobileDependency).
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
// androidMobileDep is set on Android to inject the MobileDependency for runs
// started through the generic entry points (Run/RunAsync, e.g. embed.Client).
// nil on other platforms, where the dependency is empty.
var androidMobileDep func(config *profilemanager.Config) MobileDependency
// mobileDependency returns the MobileDependency for a run started via the
// generic entry points. On Android the androidMobileDep provider supplies
// platform stubs (or real implementations); elsewhere it is empty.
func (c *ConnectClient) mobileDependency(config *profilemanager.Config) MobileDependency {
if androidMobileDep != nil {
return androidMobileDep(config)
}
return MobileDependency{}
}
type ConnectClient struct {
ctx context.Context
runCancel context.CancelFunc
runExited chan struct{}
runOnce sync.Once
runStarted atomic.Bool
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
@@ -68,41 +74,62 @@ type ConnectClient struct {
updateManager *updater.Manager
persistSyncResponse bool
// sup serializes all start/stop requests so two lifecycle operations can
// never overlap. See connect_lifecycle.go.
sup *supervisor
}
func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
) *ConnectClient {
// Derive the run context here so Stop owns the cancel that unblocks the run
// loop. runCancel is set once at construction, so Stop can call it without
// racing the run loop's startup. Callers therefore need not cancel before Stop.
runCtx, runCancel := context.WithCancel(ctx)
return &ConnectClient{
ctx: runCtx,
runCancel: runCancel,
runExited: make(chan struct{}),
config: config,
c := &ConnectClient{
ctx: ctx,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
c.sup = newSupervisor(ctx, c.run)
return c
}
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
c.updateManager = um
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
if androidRunOverride != nil {
return androidRunOverride(c, runningChan, logPath)
}
return c.run(MobileDependency{}, runningChan, logPath)
// Run with main logic. md carries optional gRPC metadata (e.g. the UI
// user-agent) to forward to the management/signal services; nil when none.
func (c *ConnectClient) Run(config *profilemanager.Config, md metadata.MD, logPath string) error {
return c.sup.start(config, md, c.mobileDependency(config), logPath)
}
// RunAsync starts a client run without blocking. Used by the daemon and embed,
// which drive the lifecycle through the supervisor rather than blocking on Run;
// they then wait for the outcome via WaitEstablishedOrDone. The run's lifecycle
// channels are created and owned by the supervisor — callers never hold them.
func (c *ConnectClient) RunAsync(config *profilemanager.Config, md metadata.MD) {
c.sup.startAsync(config, md, c.mobileDependency(config), "", nil)
}
// Restart atomically stops any in-flight run and starts a fresh one with the
// given config. The stop+start happens as a single supervisor operation, so no
// other lifecycle request can interleave between them — used for explicit
// restarts (e.g. an MDM policy change) that must not expose a "stopped" window.
func (c *ConnectClient) Restart(config *profilemanager.Config, md metadata.MD) {
c.sup.restartAsync(config, md, c.mobileDependency(config), "")
}
// WaitEstablishedOrDone blocks until the in-flight run becomes established (nil),
// ends before that (the run error, or a sentinel on a clean stop), or ctx is
// cancelled. Returns errNoRunInFlight if no run is in flight. Wraps the wait on
// the supervisor-owned channels so callers never touch them directly.
func (c *ConnectClient) WaitEstablishedOrDone(ctx context.Context) error {
return c.sup.waitEstablishedOrDone(ctx)
}
// RunOnAndroid with main logic on mobile system
func (c *ConnectClient) RunOnAndroid(
config *profilemanager.Config,
tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
@@ -121,10 +148,11 @@ func (c *ConnectClient) RunOnAndroid(
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
return c.sup.start(config, nil, mobileDependency, "")
}
func (c *ConnectClient) RunOniOS(
config *profilemanager.Config,
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
@@ -142,15 +170,12 @@ func (c *ConnectClient) RunOniOS(
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, logFilePath)
return c.sup.start(config, nil, mobileDependency, logFilePath)
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
// Mark the loop as started and signal exit on return so Stop can wait for
// the loop to finish (and skip the wait if the loop never ran).
c.runStarted.Store(true)
defer c.runOnce.Do(func() { close(c.runExited) })
// run executes a single client run. runCtx is owned by the supervisor: cancelling
// it tears the run down (it is the parent of the per-attempt engine context).
func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Config, mobileDependency MobileDependency, connEstablishedChan chan struct{}, logPath string) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -214,18 +239,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}()
wrapErr := state.Wrap
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
return wrapErr(err)
}
var mgmTlsEnabled bool
if c.config.ManagementURL.Scheme == "https" {
if config.ManagementURL.Scheme == "https" {
mgmTlsEnabled = true
}
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
@@ -259,13 +284,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
if c.ctx.Err() != nil {
if runCtx.Err() != nil {
return nil
}
state.Set(StatusConnecting)
engineCtx, cancel := context.WithCancel(c.ctx)
engineCtx, cancel := context.WithCancel(runCtx)
defer func() {
_, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
@@ -273,8 +298,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
cancel()
}()
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
@@ -291,7 +316,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
defer func() {
if err = mgmClient.Close(); err != nil {
log.Warnf("failed to close the Management service client %v", err)
@@ -300,13 +325,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
loginStarted := time.Now()
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, config)
if err != nil {
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
c.runCancel()
// No teardown needed: login fails before the engine is started
// (engine.Start is below), so there is nothing running to stop.
return backoff.Permanent(wrapErr(err)) // unrecoverable error
}
return wrapErr(err)
@@ -360,7 +386,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, logPath)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -404,7 +430,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine = engine
c.engineMutex.Unlock()
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
if err := engine.Start(loginResp.GetNetbirdConfig(), config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
@@ -412,12 +438,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
select {
case <-runningChan:
default:
close(runningChan)
}
// The supervisor owns connEstablishedChan and it is always present. Guard
// against a double close: operation re-runs on ErrResetConnection retries
// within the same run, and the channel is closed only on the first connect.
select {
case <-connEstablishedChan:
default:
close(connEstablishedChan)
}
<-engineCtx.Done()
@@ -426,8 +453,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine = nil
c.engineMutex.Unlock()
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
// Always tear the engine down once its context is cancelled. engine.Stop
// is nil-guarded per component, so calling it unconditionally is safe and
// avoids both the data race on engine.wgInterface and skipping teardown
// when the interface was never brought up (e.g. a mid-start failure).
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
@@ -445,12 +474,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
c.statusRecorder.ClientStart()
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
err = backoff.Retry(operation, backOff)
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// Login failed permanently: the engine was never started, so there
// is nothing to tear down — just record that a login is needed.
state.Set(StatusNeedsLogin)
c.runCancel()
}
return err
}
@@ -471,6 +501,22 @@ func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
return relayCfg.GetUrls(), token
}
// ConnectionRunning reports whether a connection run is currently in flight
// (connecting, connected, or reconnecting). Answered by the supervisor via a
// serialized query, so it settles behind an in-flight stop. Distinct from
// ServiceRunning, which reports whether the service itself is alive.
func (c *ConnectClient) ConnectionRunning() bool {
return c.sup.isRunning()
}
// ServiceRunning reports whether the client's lifecycle supervisor is alive and
// able to accept start/stop commands — i.e. its context has not been cancelled
// (the daemon is not shutting down). Independent of whether a connection run is
// up (that is ConnectionRunning).
func (c *ConnectClient) ServiceRunning() bool {
return c.sup.ctx.Err() == nil
}
func (c *ConnectClient) Engine() *Engine {
if c == nil {
return nil
@@ -527,12 +573,10 @@ func (c *ConnectClient) Status() StatusType {
return status
}
// Stop serializes a stop request through the lifecycle supervisor and blocks
// until the in-flight run is fully torn down.
func (c *ConnectClient) Stop() error {
c.runCancel()
if c.runStarted.Load() {
<-c.runExited
}
return nil
return c.sup.stop()
}
// SetSyncResponsePersistence enables or disables sync response persistence.

View File

@@ -7,6 +7,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
@@ -59,19 +60,17 @@ var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
var _ dns.ReadyListener = noopDnsReadyListener{}
func init() {
// Wire up the default override so embed.Client.Start() works on Android
// with netstack mode. Provides complete no-op stubs for all mobile
// Wire up the default MobileDependency provider so embed.Client.Start() works
// on Android with netstack mode. Provides complete no-op stubs for all mobile
// dependencies so the engine's existing Android code paths work unchanged.
// Applications that need P2P ICE or real DNS should replace this by
// setting androidRunOverride before calling Start().
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
return c.runOnAndroidEmbed(
// Applications that need P2P ICE or real DNS should replace this by setting
// androidMobileDep before calling Start().
androidMobileDep = func(config *profilemanager.Config) MobileDependency {
return mobileDependencyForEmbed(
noopIFaceDiscover{},
noopNetworkChangeListener{},
[]netip.AddrPort{},
noopDnsReadyListener{},
runningChan,
logPath,
)
}
}

View File

@@ -10,23 +10,18 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
// so embed.Client.Start() can detect when the engine is ready.
// It provides complete MobileDependency so the engine's existing
// Android code paths work unchanged.
func (c *ConnectClient) runOnAndroidEmbed(
// mobileDependencyForEmbed builds the MobileDependency used by embed.Client on
// Android so the engine's existing Android code paths work unchanged.
func mobileDependencyForEmbed(
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
runningChan chan struct{},
logPath string,
) error {
mobileDependency := MobileDependency{
) MobileDependency {
return MobileDependency{
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, runningChan, logPath)
}

View File

@@ -0,0 +1,362 @@
package internal
import (
"context"
"errors"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
// errAlreadyRunning is returned when a start is requested while a run is already
// in flight.
var errAlreadyRunning = errors.New("client is already running")
// errNoRunInFlight is returned by waitEstablishedOrDone when no run is active.
var errNoRunInFlight = errors.New("no connection run in flight")
// errStoppedBeforeEstablished is returned when a run ended (cleanly) before the
// connection was established.
var errStoppedBeforeEstablished = errors.New("run stopped before the connection was established")
// lifecycleOp is a serialized lifecycle operation processed by the supervisor.
type lifecycleOp int
const (
opStart lifecycleOp = iota
opStop
opRestart
opStatus
opWaitEstablished
)
// lifecycleCmd is a single lifecycle request handed to the supervisor goroutine.
// They all flow through the same cmdCh so they are strictly ordered (FIFO) with
// respect to each other.
type lifecycleCmd struct {
op lifecycleOp
config *profilemanager.Config
md metadata.MD
mobileDep MobileDependency
logPath string
// done is the caller's notification channel (nil for fire-and-forget). Its
// meaning depends on op:
// - opStart: receives the run's end result when the run terminates, or
// errAlreadyRunning immediately if a run is already in flight.
// - opStop: receives nil once the in-flight run has fully unwound.
// - opWaitEstablished: receives the wait outcome (see waitEstablishedOrDone).
done chan error
reply chan bool // opStatus only: receives whether a run is in flight
waitCtx context.Context // opWaitEstablished only: the waiter's cancellation context
}
// runState holds the lifecycle channels of a single in-flight run, owned by the
// loop goroutine. It never escapes the supervisor as an API; the only readers
// are the per-wait goroutines the loop spawns for opWaitEstablished.
//
// connEstablishedChan is closed by the run once the connection is established.
// The supervisor creates and owns it — callers no longer supply it; they observe
// it through waitEstablishedOrDone. ended is closed (broadcast) when the run
// terminates, so any number of waiters can observe it; err is the run's end
// result, valid only after ended is closed.
type runState struct {
connEstablishedChan chan struct{} // closed by the run on established
ended chan struct{} // closed by finishRun when the run terminates
err error // run end result, valid after ended is closed
}
// runEndResult is sent by the run goroutine to the supervisor when a run ends,
// whether on its own (error / external context cancellation) or because of a Stop.
type runEndResult struct {
err error
}
// runFunc executes a single client run bound to the supervisor-owned context,
// with the config supplied by the start request.
type runFunc func(ctx context.Context, config *profilemanager.Config, mobileDep MobileDependency, connEstablishedChan chan struct{}, logPath string) error
// supervisor serializes start/stop of a single client run. Every request goes
// through cmdCh and is handled one at a time by the loop goroutine, so two
// lifecycle operations can never overlap and their order is preserved (FIFO).
// The loop goroutine is the sole owner of curStart/runCancel, so that state
// needs no locking. The loop exits when the parent context is cancelled.
type supervisor struct {
ctx context.Context
run runFunc
cmdCh chan lifecycleCmd
runEnded chan runEndResult
// owned exclusively by the loop goroutine. curStart is the in-flight start
// command (nil = idle); its done channel is notified when the run ends.
// curRun holds that run's lifecycle channels; runCancel cancels it.
curStart *lifecycleCmd
curRun *runState
runCancel context.CancelFunc
}
func newSupervisor(ctx context.Context, run runFunc) *supervisor {
s := &supervisor{
ctx: ctx,
run: run,
cmdCh: make(chan lifecycleCmd, 16),
runEnded: make(chan runEndResult, 1),
}
go s.loop()
return s
}
func (s *supervisor) loop() {
for {
select {
case <-s.ctx.Done():
s.shutdown()
return
case cmd := <-s.cmdCh:
switch cmd.op {
case opStart:
s.handleStart(cmd)
case opStop:
s.handleStop(cmd)
case opRestart:
s.handleRestart(cmd)
case opStatus:
cmd.reply <- (s.isRunningInternal())
case opWaitEstablished:
s.handleWaitEstablished(cmd)
}
case res := <-s.runEnded:
// Run ended on its own, without an explicit Stop.
s.finishRun(res.err)
}
}
}
func (s *supervisor) handleStart(cmd lifecycleCmd) {
if s.isRunningInternal() {
notify(cmd.done, errAlreadyRunning)
return
}
runCtx, cancel := context.WithCancel(s.ctx)
if cmd.md != nil {
// Carry caller-supplied gRPC metadata (e.g. UI user-agent) into the run
// context so the engine's management/signal calls forward it. The cancel
// still drives runCtx (metadata wrapping preserves cancellation).
runCtx = metadata.NewOutgoingContext(runCtx, cmd.md)
}
s.runCancel = cancel
s.curStart = &cmd
s.curRun = &runState{connEstablishedChan: make(chan struct{}), ended: make(chan struct{})}
go func(ctx context.Context, cfg *profilemanager.Config, m MobileDependency, established chan struct{}, lp string) {
err := s.run(ctx, cfg, m, established, lp)
s.runEnded <- runEndResult{err: err}
}(runCtx, cmd.config, cmd.mobileDep, s.curRun.connEstablishedChan, cmd.logPath)
}
func (s *supervisor) handleStop(cmd lifecycleCmd) {
if !s.isRunningInternal() {
notify(cmd.done, nil)
return
}
s.stopCurrentRun()
notify(cmd.done, nil)
}
// handleRestart tears down any in-flight run and starts a fresh one in a single
// loop turn. No other command can interleave between the stop and the start
// (the loop is single-threaded), so the swap is atomic without relying on any
// daemon-side lock — that is what an explicit restart (e.g. MDM config change)
// needs to avoid a window where the client is observably stopped.
func (s *supervisor) handleRestart(cmd lifecycleCmd) {
if s.isRunningInternal() {
s.stopCurrentRun()
}
s.handleStart(cmd)
}
// stopCurrentRun cancels the in-flight run and blocks the supervisor until it
// has fully unwound, so the next action starts from a clean slate. The run
// goroutine reports completion via runEnded. Caller must hold an in-flight run
// (curStart != nil).
func (s *supervisor) stopCurrentRun() {
s.runCancel()
res := <-s.runEnded
s.finishRun(res.err)
}
// finishRun resets lifecycle state after a run terminates and hands the run
// error back to whoever asked to be notified of the start.
func (s *supervisor) finishRun(err error) {
s.runCancel = nil
if s.isRunningInternal() {
// Publish the result to the broadcast channel before nil-ing curRun, so
// any opWaitEstablished goroutines blocked on ended observe err.
s.curRun.err = err
close(s.curRun.ended)
s.curRun = nil
notify(s.curStart.done, err)
s.curStart = nil
}
}
// handleWaitEstablished answers an opWaitEstablished request. The select itself
// runs in a spawned goroutine on the run's channels so it never blocks the loop;
// the loop only snapshots the in-flight run's channels (which it owns) here.
func (s *supervisor) handleWaitEstablished(cmd lifecycleCmd) {
caller := cmd.done
if !s.isRunningInternal() {
notify(caller, errNoRunInFlight)
return
}
rs := s.curRun
established := rs.connEstablishedChan
ctx := cmd.waitCtx
go func() {
select {
case <-established:
notify(caller, nil)
case <-rs.ended:
if rs.err != nil {
notify(caller, rs.err)
return
}
notify(caller, errStoppedBeforeEstablished)
case <-ctx.Done():
notify(caller, ctx.Err())
}
}()
}
// shutdown tears down the in-flight run when the parent context is cancelled,
// then fails any still-queued commands so their callers never hang.
func (s *supervisor) shutdown() {
if s.runCancel != nil {
s.runCancel()
res := <-s.runEnded
s.finishRun(res.err)
}
for {
select {
case cmd := <-s.cmdCh:
notify(cmd.done, s.ctx.Err())
default:
return
}
}
}
// startAsync enqueues a start without blocking. If done is non-nil it receives
// the run's end result (or errAlreadyRunning on rejection, or the context error
// on shutdown).
func (s *supervisor) startAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string, done chan error) {
cmd := lifecycleCmd{op: opStart, config: config, md: md, mobileDep: mobileDep, logPath: logPath, done: done}
select {
case s.cmdCh <- cmd:
case <-s.ctx.Done():
notify(done, s.ctx.Err())
}
}
// restartAsync enqueues an atomic stop+start without blocking. The supervisor
// tears down any in-flight run and starts a fresh one with the supplied config
// in a single loop turn (see handleRestart). Fire-and-forget: the new run owns
// its lifecycle channels, observed via waitEstablishedOrDone.
func (s *supervisor) restartAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) {
cmd := lifecycleCmd{op: opRestart, config: config, md: md, mobileDep: mobileDep, logPath: logPath}
select {
case s.cmdCh <- cmd:
case <-s.ctx.Done():
}
}
// start enqueues a start and blocks until the run terminates, preserving the
// blocking contract of the legacy Run entry points.
func (s *supervisor) start(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) error {
done := make(chan error, 1)
s.startAsync(config, md, mobileDep, logPath, done)
select {
case err := <-done:
return err
case <-s.ctx.Done():
return s.ctx.Err()
}
}
// isRunning asks the loop whether a run is in flight. The query is serialized
// with start/stop, so during a stop it waits for the teardown to settle and
// then reports the final state — never a transient "half-stopped".
func (s *supervisor) isRunning() bool {
reply := make(chan bool, 1)
select {
case s.cmdCh <- lifecycleCmd{op: opStatus, reply: reply}:
case <-s.ctx.Done():
return false
}
select {
case r := <-reply:
return r
case <-s.ctx.Done():
return false
}
}
func (s *supervisor) isRunningInternal() bool {
return s.curStart != nil
}
// waitEstablishedOrDone blocks until the in-flight run becomes established
// (returns nil) or ends before that (returns the run error, or
// errStoppedBeforeEstablished on a clean stop), or ctx is cancelled. Returns
// errNoRunInFlight if no run is in flight. The wait is performed by a goroutine
// spawned inside the loop (see handleWaitEstablished); the run's channels never
// leave the supervisor.
func (s *supervisor) waitEstablishedOrDone(ctx context.Context) error {
reply := make(chan error, 1)
select {
case s.cmdCh <- lifecycleCmd{op: opWaitEstablished, waitCtx: ctx, done: reply}:
case <-ctx.Done():
return ctx.Err()
case <-s.ctx.Done():
return s.ctx.Err()
}
select {
case err := <-reply:
return err
case <-s.ctx.Done():
return s.ctx.Err()
}
}
// stop enqueues a stop and blocks until the in-flight run is fully torn down.
func (s *supervisor) stop() error {
done := make(chan error, 1)
select {
case s.cmdCh <- lifecycleCmd{op: opStop, done: done}:
case <-s.ctx.Done():
return s.ctx.Err()
}
select {
case err := <-done:
return err
case <-s.ctx.Done():
return s.ctx.Err()
}
}
// notify sends on a caller-supplied channel without blocking. The channel is
// expected to be buffered (cap 1); a nil channel means the caller did not ask
// to be notified.
func notify(ch chan error, err error) {
if ch == nil {
return
}
select {
case ch <- err:
default:
}
}

View File

@@ -51,20 +51,13 @@ type cachedRecord struct {
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
// guarded by mutex.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
@@ -83,10 +76,9 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
}
}
@@ -181,9 +173,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype. When one family hard-errors while the other succeeds,
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -213,10 +203,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil
}
@@ -476,7 +462,6 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -520,7 +505,6 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
}
m.addNewDomains(ctx, newDomains)
@@ -593,85 +577,13 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches domains that are not yet in the cache,
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains {
if _, dup := seen[newDomain]; dup {
continue
}
seen[newDomain] = struct{}{}
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved now. A recent failed or
// incomplete resolve gates retries on the backoff even when one family is
// already cached, so a transiently-failed family is retried instead of being
// treated as fully resolved. Otherwise a domain with any cached record is left
// to the stale-while-revalidate refresh path.
func (m *Resolver) needsResolve(d domain.Domain) bool {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.RLock()
defer m.mutex.RUnlock()
if failedAt, ok := m.failedResolves[d]; ok {
return time.Since(failedAt) >= refreshBackoff
}
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return false
}
}
return true
}
func (m *Resolver) markResolveFailed(d domain.Domain) {
m.mutex.Lock()
m.failedResolves[d] = time.Now()
m.mutex.Unlock()
}
func (m *Resolver) clearResolveFailed(d domain.Domain) {
m.mutex.Lock()
delete(m.failedResolves, d)
m.mutex.Unlock()
}
// pruneFailedResolves drops failure markers for domains no longer present in
// the server-domains set, keeping the map bounded to the current set (a
// failed-only domain has no cached record, so RemoveDomain never sees it).
func (m *Resolver) pruneFailedResolves(domains domain.List) {
m.mutex.Lock()
defer m.mutex.Unlock()
for d := range m.failedResolves {
if !slices.Contains(domains, d) {
delete(m.failedResolves, d)
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}

View File

@@ -21,7 +21,6 @@ type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
qErr map[string]error
err error
hasRoot bool
onLookup func()
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true,
}
}
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++
answers := f.answers[key]
err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup
f.mu.Unlock()
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
}
}
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
f.mu.Lock()
defer f.mu.Unlock()
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -1,183 +0,0 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// New domains in a single update must resolve concurrently rather than serially.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
var inflight, maxInflight atomic.Int32
chain.onLookup = func() {
n := inflight.Add(1)
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// When one family hard-errors while the other resolves, the domain is cached
// for the working family but recorded as incomplete so the failed family is
// retried under backoff instead of being treated as fully resolved forever.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
d := domain.Domain("relay.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
require.True(t, aCached, "the working family must still be cached")
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
r.mutex.Lock()
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
r.mutex.Unlock()
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
}
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
// not a failure: the domain must not be marked for retry, otherwise it would be
// re-resolved on every sync.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
d := domain.Domain("v4only.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -207,35 +207,3 @@ func FormatAnswers(answers []dns.RR) string {
}
return "[" + strings.Join(parts, ", ") + "]"
}
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
// RFC 6891 a responder must not include an OPT RR toward a client that did not
// advertise EDNS0.
func StripOPT(msg *dns.Msg) {
if len(msg.Extra) == 0 {
return
}
out := msg.Extra[:0]
for _, rr := range msg.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
msg.Extra = out
}
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
// the message, if present.
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
opt := msg.IsEdns0()
if opt == nil {
return nil, false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok {
return ede, true
}
}
return nil, false
}

View File

@@ -120,42 +120,3 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
StripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestExtractEDE(t *testing.T) {
t.Run("no edns", func(t *testing.T) {
_, ok := ExtractEDE(&dns.Msg{})
assert.False(t, ok, "message without OPT has no EDE")
})
t.Run("edns without ede", func(t *testing.T) {
rm := &dns.Msg{}
rm.SetEdns0(4096, false)
_, ok := ExtractEDE(rm)
assert.False(t, ok, "OPT without EDE option returns false")
})
t.Run("with ede", func(t *testing.T) {
rm := &dns.Msg{}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
rm.Extra = append(rm.Extra, opt)
ede, ok := ExtractEDE(rm)
assert.True(t, ok, "EDE option should be found")
assert.Equal(t, uint16(49152), ede.InfoCode)
assert.Equal(t, "upstream timeout", ede.ExtraText)
})
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
@@ -39,15 +38,11 @@ const (
// defaultWarningDelayBase is the starting grace window before a
// "Nameserver group unreachable" event fires for a group that's
// never been healthy and only has overlay upstreams with no
// Connected peer. Per-server and overridable via envWarningDelay;
// see warningDelay.
defaultWarningDelayBase = 60 * time.Second
// Connected peer. Per-server and overridable; see warningDelayFor.
defaultWarningDelayBase = 30 * time.Second
// warningDelayBonusCap caps the route-count bonus added to the
// base grace window. See warningDelay.
// base grace window. See warningDelayFor.
warningDelayBonusCap = 30 * time.Second
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
)
// errNoUsableNameservers signals that a merged-domain group has no usable
@@ -140,7 +135,7 @@ type DefaultServer struct {
disableSys bool
mux sync.Mutex
service service
dnsMuxHandlers []handlerWrapper
dnsMuxMap registeredHandlerMap
localResolver *local.Resolver
wgInterface WGIface
hostManager hostManager
@@ -204,6 +199,8 @@ type handlerWrapper struct {
priority int
}
type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
@@ -292,6 +289,7 @@ func newDefaultServer(
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(),
wgInterface: wgInterface,
statusRecorder: statusRecorder,
@@ -300,7 +298,7 @@ func newDefaultServer(
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
warningDelayBase: warningDelayBaseFromEnv(),
warningDelayBase: defaultWarningDelayBase,
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
@@ -330,7 +328,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
type routeSettable interface {
setSelectedRoutes(func() route.HAMap)
}
for _, entry := range s.dnsMuxHandlers {
for _, entry := range s.dnsMuxMap {
if h, ok := entry.handler.(routeSettable); ok {
h.setSelectedRoutes(selected)
}
@@ -980,23 +978,19 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxHandlers {
for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority)
// The local resolver is a persistent singleton shared by every custom
// zone and reused across config updates. Its chain registrations are
// per-config and must be deregistered, but Stop() cancels its lookup
// context (breaking external CNAME-target resolution) and clears its
// records, so it must not be torn down here.
if existing.handler != s.localResolver {
existing.handler.Stop()
}
existing.handler.Stop()
}
muxUpdateMap := make(registeredHandlerMap)
for _, update := range muxUpdates {
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update
}
s.dnsMuxHandlers = muxUpdates
s.dnsMuxMap = muxUpdateMap
}
// updateNSGroupStates records the new group set and pokes the refresher.
@@ -1160,26 +1154,6 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
return false
}
// warningDelayBaseFromEnv returns the base grace window, honoring
// envWarningDelay when it holds a valid positive Go duration. Invalid or
// non-positive values fall back to defaultWarningDelayBase.
func warningDelayBaseFromEnv() time.Duration {
val := os.Getenv(envWarningDelay)
if val == "" {
return defaultWarningDelayBase
}
d, err := time.ParseDuration(val)
if err != nil {
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
return defaultWarningDelayBase
}
if d <= 0 {
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
return defaultWarningDelayBase
}
return d
}
// warningDelay returns the grace window for the given selected-route
// count. Scales gently: +1s per 100 routes, capped by
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
@@ -1230,7 +1204,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
// in more than one handler.
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
merged := make(map[netip.AddrPort]UpstreamHealth)
for _, entry := range s.dnsMuxHandlers {
for _, entry := range s.dnsMuxMap {
reporter, ok := entry.handler.(upstreamHealthReporter)
if !ok {
continue

View File

@@ -104,6 +104,19 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
@@ -119,20 +132,22 @@ func TestUpdateDNSServer(t *testing.T) {
},
}
dummyHandler := local.NewResolver()
testCases := []struct {
name string
initUpstreamMap []handlerWrapper
initUpstreamMap registeredHandlerMap
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap []handlerWrapper
expectedUpstreamMap registeredHandlerMap
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: nil,
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -154,17 +169,20 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
{
dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
{
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
@@ -173,10 +191,10 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
handler: &mockHandler{},
handler: dummyHandler,
priority: PriorityUpstream,
},
},
@@ -197,13 +215,15 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
{
"local-resolver": handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
},
@@ -212,7 +232,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
@@ -220,7 +240,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -242,7 +262,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -264,7 +284,7 @@ func TestUpdateDNSServer(t *testing.T) {
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
@@ -281,41 +301,42 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: []handlerWrapper{{
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
handler: dummyHandler,
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: nil,
expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
handler: dummyHandler,
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: nil,
expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalQs: []dns.Question{},
},
}
@@ -372,7 +393,7 @@ func TestUpdateDNSServer(t *testing.T) {
}
}()
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
@@ -384,20 +405,14 @@ func TestUpdateDNSServer(t *testing.T) {
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
}
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
if !found {
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
}
}
@@ -497,8 +512,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
}()
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
@@ -1014,15 +1029,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {}
func TestDefaultServer_UpdateMux(t *testing.T) {
baseMatchHandlers := []handlerWrapper{
{
baseMatchHandlers := registeredHandlerMap{
"upstream-group1": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
},
{
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
@@ -1031,15 +1046,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
},
}
baseRootHandlers := []handlerWrapper{
{
baseRootHandlers := registeredHandlerMap{
"upstream-root1": {
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
{
"upstream-root2": {
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
@@ -1048,22 +1063,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
},
}
baseMixedHandlers := []handlerWrapper{
{
baseMixedHandlers := registeredHandlerMap{
"upstream-group1": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
},
{
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
},
{
"upstream-other": {
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
@@ -1074,7 +1089,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
tests := []struct {
name string
initialHandlers []handlerWrapper
initialHandlers registeredHandlerMap
updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain
description string
@@ -1358,38 +1373,32 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &DefaultServer{
dnsMuxHandlers: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
dnsMuxMap: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Perform the update
server.updateMux(tt.updates)
// Verify the results
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
"Number of handlers after update doesn't match expected")
// Check each expected handler
for id, expectedDomain := range tt.expectedHandlers {
var found *handlerWrapper
for i := range server.dnsMuxHandlers {
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
found = &server.dnsMuxHandlers[i]
break
}
}
assert.NotNil(t, found, "Expected handler %s not found", id)
if found != nil {
assert.Equal(t, expectedDomain, found.domain,
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id)
if exists {
assert.Equal(t, expectedDomain, handler.domain,
"Domain mismatch for handler %s", id)
}
}
// Verify no unexpected handlers exist
for _, entry := range server.dnsMuxHandlers {
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
for HandlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(HandlerID)]
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
}
// Verify the handlerChain state and order
@@ -1404,7 +1413,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Verify handler exists in mux
foundInMux := false
for _, muxEntry := range server.dnsMuxHandlers {
for _, muxEntry := range server.dnsMuxMap {
if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
@@ -1413,108 +1422,12 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}
}
assert.True(t, foundInMux,
"Handler in chain not found in dnsMuxHandlers")
"Handler in chain not found in dnsMuxMap")
}
})
}
}
// chainHasPattern reports whether the handler chain holds an entry registered
// for the given fqdn pattern at the given priority.
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
for _, h := range s.handlerChain.handlers {
if h.OrigPattern == pattern && h.Priority == priority {
return true
}
}
return false
}
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
// tracks each (handler, domain) registration independently when one handler
// serves multiple zones. Every custom zone is served by the same handler
// instance (the local resolver, whose ID is the constant "local-resolver"), so
// removing one zone must deregister exactly that zone's chain entry and leave
// the others in place. Tracking registrations by handler ID alone collapses all
// zones onto one entry, leaving removed zones in the chain to answer
// authoritatively with no records.
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
// One handler serves every custom zone, mirroring s.localResolver.
shared := &mockHandler{Id: "local-resolver"}
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Two custom zones under the same handler. The surviving zone is registered
// last, mirroring the management emission order.
server.updateMux([]handlerWrapper{
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test should be registered after the first update")
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should be registered after the first update")
// Remove one zone, keep the other.
server.updateMux([]handlerWrapper{
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should remain after removing userzone.test")
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test handler must be deregistered, not leaked in the chain")
}
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
// does not tear down the shared local resolver during reconfiguration. The
// resolver is a process-lifetime singleton reused across config updates;
// Stop() cancels its lookup context (breaking external CNAME-target
// resolution) and clears its records. updateMux must deregister its chain
// entries without stopping it. Records surviving a teardown update is the
// observable proxy: Stop() would have cleared them.
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
resolver := local.NewResolver()
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
Name: "peer.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}))
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
localResolver: resolver,
}
server.updateMux([]handlerWrapper{
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
})
// Remove the zone. The resolver must survive so its records and lookup
// context stay intact for the next registration.
server.updateMux(nil)
var response *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
response = m
return nil
},
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
require.NotNil(t, response, "local resolver should answer after teardown")
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
}
func TestExtraDomains(t *testing.T) {
tests := []struct {
name string
@@ -2136,6 +2049,7 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
@@ -2293,7 +2207,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
}
}
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
@@ -2369,11 +2283,12 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
@@ -2480,6 +2395,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
@@ -2491,7 +2407,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2528,6 +2444,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
@@ -2542,7 +2459,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2567,32 +2484,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestWarningDelayBaseFromEnv(t *testing.T) {
tests := []struct {
name string
set bool
val string
want time.Duration
}{
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envWarningDelay, tc.val)
if !tc.set {
os.Unsetenv(envWarningDelay)
}
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
})
}
}
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
@@ -2704,6 +2595,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
@@ -2721,7 +2613,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2748,6 +2640,7 @@ func TestDNSLoopPrevention(t *testing.T) {
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
tests := []struct {

View File

@@ -443,32 +443,29 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
// A valid response means the upstream is reachable, whatever the Rcode.
u.markUpstreamOk(upstream)
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
// refused zones, transient recursion errors), not reachability
// problems: fail over for a better answer but keep the upstream healthy.
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
resutil.StripOPT(rm)
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
if !hadEdns {
resutil.StripOPT(rm)
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}
@@ -523,6 +520,22 @@ func upstreamUDPSize() uint16 {
return dns.MinMsgSize
}
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()}

View File

@@ -517,78 +517,6 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
// those are per-question outcomes from a reachable server and must not mark
// the upstream unhealthy. Only transport failures (timeouts) do.
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.10:53")
b := netip.MustParseAddrPort("192.0.2.11:53")
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
tests := []struct {
name string
respA mockUpstreamResponse
respB mockUpstreamResponse
wantHealthy bool
}{
{
name: "both SERVFAIL are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
wantHealthy: true,
},
{
name: "both REFUSED are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
wantHealthy: true,
},
{
name: "timeout marks unhealthy",
respA: mockUpstreamResponse{err: timeoutErr},
respB: mockUpstreamResponse{err: timeoutErr},
wantHealthy: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): tc.respA,
b.String(): tc.respB,
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a, b})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, a, "primary upstream should have a health record")
if tc.wantHealthy {
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
} else {
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
}
})
}
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string
@@ -985,6 +913,19 @@ func TestEDEName(t *testing.T) {
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")

View File

@@ -26,15 +26,6 @@ import (
const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second
// EDE info codes the forwarder emits on upstream failures so the querying
// client can see the reason without inspecting this peer's logs. They live in
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
// real upstream EDE here, so these cannot collide with a genuine code.
const (
edeNetbirdUpstreamTimeout uint16 = 49152
edeNetbirdUpstreamFailure uint16 = 49153
)
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
@@ -229,7 +220,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return
}
@@ -342,7 +333,6 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
reqHasEdns bool,
startTime time.Time,
) {
qType := question.Qtype
@@ -384,10 +374,6 @@ func (f *DNSForwarder) handleDNSError(
logger.Warnf(errResolveFailed, domain, result.Err)
}
if reqHasEdns {
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
}
f.writeResponse(logger, w, resp, domain, startTime)
}
@@ -428,33 +414,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches
}
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
func edeCodeFor(dnsErr *net.DNSError) uint16 {
if dnsErr != nil && dnsErr.IsTimeout {
return edeNetbirdUpstreamTimeout
}
return edeNetbirdUpstreamFailure
}
// edeText builds the EDE extra-text describing the class of upstream failure.
// It deliberately omits the upstream server address, which may be an internal
// resolver and is exposed to any client permitted to use the route; the full
// detail stays in the forwarder's local log.
func edeText(dnsErr *net.DNSError) string {
if dnsErr != nil && dnsErr.IsTimeout {
return "netbird forwarder: upstream timeout"
}
return "netbird forwarder: upstream failure"
}
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
// creating the OPT pseudo-record if the response does not already carry one.
func attachEDE(resp *dns.Msg, code uint16, text string) {
opt := resp.IsEdns0()
if opt == nil {
resp.SetEdns0(dns.DefaultMsgSize, false)
opt = resp.IsEdns0()
}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
@@ -618,85 +617,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
}
}
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string
lookupErr error
reqEdns bool
wantEDE bool
wantCode uint16
wantTextHas string
}{
{
name: "timeout with edns0",
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamTimeout,
wantTextHas: "netbird forwarder: upstream timeout",
},
{
name: "server failure with edns0",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamFailure,
wantTextHas: "netbird forwarder: upstream failure",
},
{
name: "no edns0 in request omits ede",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: false,
wantEDE: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr(nil), tt.lookupErr).Once()
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
if tt.reqEdns {
query.SetEdns0(dns.DefaultMsgSize, false)
}
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
mockResolver.AssertExpectations(t)
require.NotNil(t, writtenResp, "expected a response")
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
ede, ok := resutil.ExtractEDE(writtenResp)
if !tt.wantEDE {
assert.False(t, ok, "response must not carry EDE")
return
}
require.True(t, ok, "response must carry EDE")
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}

View File

@@ -22,6 +22,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
@@ -86,8 +88,6 @@ const (
var ErrResetConnection = fmt.Errorf("reset connection")
var ErrEngineAlreadyStarted = errors.New("engine already started")
type EngineConfig struct {
WgPort int
WgIfaceName string
@@ -201,8 +201,6 @@ type Engine struct {
ctx context.Context
cancel context.CancelFunc
started bool
wgInterface WGIface
udpMux *udpmux.UniversalUDPMuxDefault
@@ -283,15 +281,9 @@ func NewEngine(
services EngineServices,
mobileDep MobileDependency,
) *Engine {
// The engine is single-use: a fresh instance is built per connection
// cycle (see Client.run), so the run context is created once here rather
// than in Start.
ctx, cancel := context.WithCancel(clientCtx)
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
ctx: ctx,
cancel: cancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
@@ -324,34 +316,8 @@ func (e *Engine) Stop() error {
log.Debugf("tried stopping engine that is nil")
return nil
}
e.cancel()
e.syncMsgMux.Lock()
e.stopLocked()
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// stopLocked tears down everything Start may have brought up, in the order
// teardown requires (DNS before the interface goes down, flow manager after).
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
// path, so a partially-initialized engine is cleaned up the same way; every
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
// after releasing the lock, since the goroutines also take syncMsgMux.
func (e *Engine) stopLocked() {
if e.connMgr != nil {
e.connMgr.Close()
}
@@ -402,6 +368,10 @@ func (e *Engine) stopLocked() {
// so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.cancel != nil {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish
e.close()
@@ -420,6 +390,21 @@ func (e *Engine) stopLocked() {
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
@@ -457,38 +442,18 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// The engine is single-use. Reject a duplicate start and a start on an
// already-stopped engine (run context cancelled).
if e.started {
return ErrEngineAlreadyStarted
}
if ctxErr := e.ctx.Err(); ctxErr != nil {
return fmt.Errorf("engine already stopped: %w", ctxErr)
}
e.started = true
// Tear down any partially-initialized state on a failed start. Cancel the
// run context first so goroutines started before the failure (connMgr,
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
// not just what close() covers.
defer func() {
if err != nil {
e.cancel()
e.stopLocked()
}
}()
if err = iface.ValidateMTU(e.config.MTU); err != nil {
if err := iface.ValidateMTU(e.config.MTU); err != nil {
return fmt.Errorf("invalid MTU configuration: %w", err)
}
if e.cancel != nil {
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface()
@@ -522,11 +487,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err)
}
dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
}
e.dnsServer = dnsServer
@@ -561,6 +528,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err)
}
@@ -569,6 +537,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
}
@@ -580,6 +549,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.udpMux, err = e.wgInterface.Up()
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
e.close()
return fmt.Errorf("up wg interface: %w", err)
}
@@ -604,7 +574,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.acl = acl.NewDefaultManager(e.firewall)
}
if err := e.dnsServer.Initialize(); err != nil {
err = e.dnsServer.Initialize()
if err != nil {
e.close()
return fmt.Errorf("initialize dns server: %w", err)
}
@@ -616,9 +588,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start(peer.IsForceRelayed())
if err = e.receiveSignalEvents(); err != nil {
return err
}
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveJobEvents()
@@ -670,6 +640,7 @@ func (e *Engine) createFirewall() error {
func (e *Engine) initFirewall() error {
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
e.close()
return fmt.Errorf("set firewall: %w", err)
}
@@ -1066,7 +1037,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
@@ -1097,20 +1068,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
return nil
}
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
// can be excluded from the reported network addresses; the interface coming and
// going otherwise churns the peer meta on the management server.
func (e *Engine) overlayAddresses() []netip.Addr {
var ips []netip.Addr
if e.config.WgAddr.IP.IsValid() {
ips = append(ips, e.config.WgAddr.IP)
}
if e.config.WgAddr.HasIPv6() {
ips = append(ips, e.config.WgAddr.IPv6)
}
return ips
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
@@ -1170,6 +1127,20 @@ func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
}
// wrapDisconnectError classifies a receive-loop failure before the run is torn
// down. An auth rejection (PermissionDenied/Unauthenticated) means the session
// needs re-login and retrying is futile, so mark it terminal (NeedsLogin) — run()
// then exits on its own instead of spinning the backoff. Any other failure is a
// recoverable connection reset that the backoff should retry.
func (e *Engine) wrapDisconnectError(err error) {
state := CtxGetState(e.ctx)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied || s.Code() == codes.Unauthenticated) {
state.Set(StatusNeedsLogin)
return
}
_ = state.Wrap(ErrResetConnection)
}
func (e *Engine) receiveJobEvents() {
e.jobExecutorWG.Add(1)
go func() {
@@ -1196,9 +1167,9 @@ func (e *Engine) receiveJobEvents() {
}
})
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
// happens if management is unavailable for a long time, or rejects
// us (auth). wrapDisconnectError decides retry vs needs-login.
e.wrapDisconnectError(err)
e.clientCancel()
return
}
@@ -1254,7 +1225,7 @@ func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
@@ -1280,9 +1251,9 @@ func (e *Engine) receiveManagementEvents() {
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
// happens if management is unavailable for a long time, or rejects
// us (auth). wrapDisconnectError decides retry vs needs-login.
e.wrapDisconnectError(err)
e.clientCancel()
return
}
@@ -1743,7 +1714,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() error {
func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
@@ -1806,20 +1777,15 @@ func (e *Engine) receiveSignalEvents() error {
return nil
})
if err != nil {
// happens if signal is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
// happens if signal is unavailable for a long time, or rejects us
// (auth). wrapDisconnectError decides retry vs needs-login.
e.wrapDisconnectError(err)
e.clientCancel()
return
}
}()
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
e.signal.WaitStreamConnected(e.ctx)
if err := e.ctx.Err(); err != nil {
return fmt.Errorf("wait for signal stream: %w", err)
}
return nil
e.signal.WaitStreamConnected()
}
func (e *Engine) parseNATExternalIPMappings() []string {

View File

@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// feed updates to Engine via mocked Management client
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)

View File

@@ -119,6 +119,10 @@ func (d *BindListener) ReadPackets() {
}
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.lazyConn.Close()
d.bind.RemoveEndpoint(d.fakeIP)
d.done.Done()

View File

@@ -195,14 +195,14 @@ func (h *Handshaker) sendOffer() error {
}
offer := h.buildOfferAnswer()
h.log.Debugf("sending offer with serial: %s", offer.SessionIDString())
h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
return h.signaler.SignalOffer(offer, h.config.Key)
}
func (h *Handshaker) sendAnswer() error {
answer := h.buildOfferAnswer()
h.log.Debugf("sending answer with serial: %s", answer.SessionIDString())
h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
return h.signaler.SignalAnswer(answer, h.config.Key)
}

View File

@@ -192,7 +192,6 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
// Pure read methods take RLock; anything that mutates state takes Lock.
type Status struct {
mux sync.RWMutex
muxRelays sync.RWMutex
peers map[string]State
ipToKey map[string]string
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
@@ -245,8 +244,8 @@ func NewRecorder(mgmAddress string) *Status {
}
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
d.muxRelays.Lock()
defer d.muxRelays.Unlock()
d.mux.Lock()
defer d.mux.Unlock()
d.relayMgr = manager
}
@@ -907,8 +906,8 @@ func (d *Status) MarkSignalConnected() {
}
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
d.muxRelays.Lock()
defer d.muxRelays.Unlock()
d.mux.Lock()
defer d.mux.Unlock()
d.relayStates = relayResults
}
@@ -1019,26 +1018,24 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.muxRelays.RLock()
d.mux.RLock()
defer d.mux.RUnlock()
if d.relayMgr == nil {
defer d.muxRelays.RUnlock()
return slices.Clone(d.relayStates)
return d.relayStates
}
relayMgr := d.relayMgr
// extend the list of stun, turn servers with the relay server connections
relayStates := slices.Clone(d.relayStates)
d.muxRelays.RUnlock()
states := relayMgr.RelayStates()
states := d.relayMgr.RelayStates()
if len(states) == 0 {
// no relay connection tracked yet; surface configured servers as
// unavailable with the real reconnect error when known
err := relayClient.ErrRelayClientNotConnected
if connErr := relayMgr.RelayConnectError(); connErr != nil {
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
err = connErr
}
for _, r := range relayMgr.ServerURLs() {
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
Err: err,

View File

@@ -433,7 +433,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerSSHAllowed != nil && (config.ServerSSHAllowed == nil || *input.ServerSSHAllowed != *config.ServerSSHAllowed) {
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
if *input.ServerSSHAllowed {
log.Infof("enabling SSH server")
} else {

View File

@@ -242,35 +242,6 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
}
}
func TestUpdateConfigServerSSHAllowedNotSet(t *testing.T) {
// Configs written before ServerSSHAllowed was introduced lack the field and
// unmarshal to nil. Supplying the SSH server flag on top of such a config must
// apply the value instead of panicking on a nil pointer dereference.
tests := []struct {
name string
input *bool
want bool
}{
{"enable", util.True(), true},
{"disable", util.False(), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0600))
config, err := UpdateConfig(ConfigInput{
ConfigPath: configPath,
ServerSSHAllowed: tt.input,
})
require.NoError(t, err)
require.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set from input")
assert.Equal(t, tt.want, *config.ServerSSHAllowed)
})
}
}
func TestUpdateOldManagementURL(t *testing.T) {
origProber := newMgmProber
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {

View File

@@ -251,14 +251,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true
}
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
// describing why a lookup failed. The OPT is stripped from the reply when
// the original client did not request EDNS0.
hadEdns := r.IsEdns0() != nil
if !hadEdns {
r.SetEdns0(dns.DefaultMsgSize, false)
}
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
@@ -268,13 +260,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if ede, ok := resutil.ExtractEDE(reply); ok {
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
}
if !hadEdns {
resutil.StripOPT(reply)
}
resutil.SetMeta(w, "peer", peerKey)
reply.Id = r.Id

View File

@@ -171,13 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
connectClient := internal.NewConnectClient(ctx, c.recorder)
c.setState(cfg, connectClient)
// Persist the latest sync response so DebugBundle can include the network
// map. On iOS this is backed by disk to keep it out of the constrained
// process memory (see the syncstore package).
connectClient.SetSyncResponsePersistence(true)
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
return connectClient.RunOniOS(cfg, fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
}
// Stop the internal client and free the resources

View File

@@ -36,7 +36,6 @@ type URLOpener interface {
// Auth can register or login new client
type Auth struct {
ctx context.Context
cancel context.CancelFunc
config *profilemanager.Config
cfgPath string
}
@@ -52,19 +51,8 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
return nil, err
}
// Use a cancellable context so Stop() can abort an in-progress interactive
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
// bound to a port) until the OAuth callback arrives or the flow expires;
// cancelling the context unblocks WaitToken, which then shuts that server down
// and frees the port for the next login attempt. iOS runs login in the main-app
// process (decoupled from the network extension), so without this the server
// lingers after the user dismisses the browser and the next connect stalls
// trying to bind the same port.
ctx, cancel := context.WithCancel(context.Background())
return &Auth{
ctx: ctx,
cancel: cancel,
ctx: context.Background(),
config: cfg,
cfgPath: cfgPath,
}, nil
@@ -72,24 +60,12 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
// NewAuthWithConfig instantiate Auth based on existing config
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
ctx, cancel := context.WithCancel(ctx)
return &Auth{
ctx: ctx,
cancel: cancel,
config: config,
}
}
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
// safe to call when no login is running.
func (a *Auth) Stop() {
if a.cancel != nil {
a.cancel()
}
}
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
// is not supported and returns false without saving the configuration. For other errors return false.

View File

@@ -344,9 +344,6 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
}
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
if s.connectClient == nil {
return nil, status.Error(codes.FailedPrecondition, "client not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")

View File

@@ -5,7 +5,6 @@ package server
import (
"bytes"
"context"
"errors"
"fmt"
"runtime/pprof"
@@ -28,11 +27,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
}
var clientMetrics debug.MetricsExporter
if s.connectClient != nil {
if engine := s.connectClient.Engine(); engine != nil {
if cm := engine.GetClientMetrics(); cm != nil {
clientMetrics = cm
}
if engine := s.connectClient.Engine(); engine != nil {
if cm := engine.GetClientMetrics(); cm != nil {
clientMetrics = cm
}
}
@@ -48,13 +45,10 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
defer s.cleanupBundleCapture()
var refreshStatus func()
if s.connectClient != nil {
engine := s.connectClient.Engine()
if engine != nil {
refreshStatus = func() {
log.Debug("refreshing system health status for debug bundle")
engine.RunHealthProbes(true)
}
if engine := s.connectClient.Engine(); engine != nil {
refreshStatus = func() {
log.Debug("refreshing system health status for debug bundle")
engine.RunHealthProbes(true)
}
}
@@ -118,9 +112,7 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
log.SetLevel(level)
if s.connectClient != nil {
s.connectClient.SetLogLevel(level)
}
s.connectClient.SetLogLevel(level)
log.Infof("Log level set to %s", level.String())
@@ -134,20 +126,13 @@ func (s *Server) SetSyncResponsePersistence(_ context.Context, req *proto.SetSyn
enabled := req.GetEnabled()
s.persistSyncResponse = enabled
if s.connectClient != nil {
s.connectClient.SetSyncResponsePersistence(enabled)
}
s.connectClient.SetSyncResponsePersistence(enabled)
return &proto.SetSyncResponsePersistenceResponse{}, nil
}
func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
cClient := s.connectClient
if cClient == nil {
return nil, errors.New("connect client is not initialized")
}
return cClient.GetLatestSyncResponse()
return s.connectClient.GetLatestSyncResponse()
}
// StartCPUProfile starts CPU profiling in the daemon.

View File

@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
@@ -39,12 +38,11 @@ type conflictCheck struct {
// OS-native managed-config store reports a diff vs the last observation.
//
// Restart sequence:
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
// 1. Stop the in-flight run via the supervisor (blocks until fully torn down).
// 2. Re-resolve Config from disk + MDM policy (Config.apply re-runs
// applyMDMPolicy with the freshly loaded Policy).
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
// 3. Start a fresh run with the new config.
// 4. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
// RPC) can refresh its cached config view without polling.
//
// The callback runs in the ticker's own goroutine. Ticker has already
@@ -52,39 +50,24 @@ type conflictCheck struct {
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
log.Warn("MDM policy changed; restarting engine to apply new configuration")
// Hold s.mutex for the entire restart sequence (cancel + quiescence
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
// MDM is restarting blocks on the Lock until we are done — they
// then observe the post-restart state coherently. This is safe
// because the connectWithRetryRuns goroutine no longer acquires
// s.mutex in its defer (intent vs. goroutine-alive concerns are
// fully separated; see the connectionGoroutineRunning helper).
// Hold s.mutex for the entire restart sequence (stop + re-start). Any
// concurrent Up/Down/Status arriving while MDM is restarting blocks on the
// Lock until we are done — they then observe the post-restart state coherently.
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.clientRunning {
// The client is not running, so there's no engine to restart.
if !s.connectClient.ConnectionRunning() {
// No run in flight, so there's no engine to restart.
return nil
}
// Cancel daemon-side login/status activities tied to the old run; the run
// itself is torn down atomically by the supervisor inside Restart (see
// restartEngineForMDMLocked), which stops and re-starts in one operation.
if s.actCancel != nil {
s.actCancel()
}
// Wait for previous connectWithRetryRuns to exit so we don't end up
// with two goroutines fighting over the same status recorder + engine.
// The teardown engages a fan-out of engine goroutines (peer workers,
// signal handler, route manager, ...). close(clientGiveUpChan)
// happens in the function-scope defer of connectWithRetryRuns, on
// every exit path (ctx cancel, backoff exhausted, panic) — see the
// defer in server.go.
if s.clientGiveUpChan != nil {
select {
case <-s.clientGiveUpChan:
case <-time.After(10 * time.Second):
return fmt.Errorf("failed to restart the engine due to timeout")
}
}
if err := s.restartEngineForMDMLocked(); err != nil {
log.Errorf("MDM restart failed: %v", err)
return err
@@ -131,14 +114,13 @@ func (s *Server) publishConfigChangedEvent(source string) {
}
// restartEngineForMDMLocked re-resolves the active profile config
// (re-running applyMDMPolicy via Config.apply) and re-spawns
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
// MDM change behaves identically to a fresh boot under the new policy.
// (re-running applyMDMPolicy via Config.apply) and starts a fresh run.
// Mirrors the tail of Server.Start so a runtime MDM change behaves
// identically to a fresh boot under the new policy.
//
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
// state.
// for the entire restart sequence so concurrent Up/Down/Status RPCs
// observe a coherent post-restart state.
func (s *Server) restartEngineForMDMLocked() error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil {
@@ -154,13 +136,13 @@ func (s *Server) restartEngineForMDMLocked() error {
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
ctx, cancel := context.WithCancel(s.rootCtx)
_, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
log.Info("MDM restart: atomically restarting the run with re-resolved config")
// MDM restart has no incoming RPC metadata; fire and forget. Restart is a
// single supervisor op (atomic stop+start), so there is no observable
// "stopped" window between tearing down the old run and starting the new.
s.connectClient.Restart(config, nil)
s.publishConfigChangedEvent("mdm")
return nil
}

View File

@@ -34,10 +34,6 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
@@ -147,10 +143,6 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
@@ -199,10 +191,6 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")

View File

@@ -8,12 +8,10 @@ import (
"os"
"os/exec"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
@@ -39,15 +37,7 @@ import (
)
const (
probeThreshold = time.Second * 5
retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME"
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
defaultInitialRetryTime = 30 * time.Minute
defaultMaxRetryInterval = 60 * time.Minute
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
probeThreshold = time.Second * 5
// JWT token cache TTL for the client daemon (disabled by default)
defaultJWTCacheTTL = 0
@@ -72,15 +62,8 @@ type Server struct {
mutex sync.Mutex
config *profilemanager.Config
proto.UnimplementedDaemonServiceServer
// clientRunning tracks "the daemon wants to be connected" — set true by
// Start / Up, cleared by Down / Logout. Persists across retry
// loops, signal disconnects, and ErrResetConnection cycles. NOT
// changed by connectWithRetryRuns goroutine exit — for that
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
// derives from clientGiveUpChan close state. Protected by s.mutex.
clientRunning bool
clientRunningChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
// Run state (in-flight? established/done channels?) is owned entirely by the
// supervisor inside connectClient — the daemon keeps no per-run fields.
connectClient *internal.ConnectClient
@@ -136,6 +119,13 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
networksDisabled: networksDisabled,
jwtCache: newJWTCache(),
}
// The ConnectClient is daemon-lifetime: build it exactly once, here. Its
// supervisor lives as long as the daemon; Up/Down/MDM and reconnects all
// drive this same instance. updateManager isn't ready yet (created in
// Start) and is injected there via SetUpdateManager.
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
agent := &serverAgent{s}
s.sleepHandler = sleephandler.New(agent)
s.startSleepDetector()
@@ -147,7 +137,7 @@ func (s *Server) Start() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.clientRunning {
if s.connectClient.ConnectionRunning() {
return nil
}
@@ -165,6 +155,7 @@ func (s *Server) Start() error {
stateMgr := statemanager.New(s.profileManager.GetStatePath())
s.updateManager = updater.NewManager(s.statusRecorder, stateMgr)
s.updateManager.CheckUpdateSuccess(s.rootCtx)
s.connectClient.SetUpdateManager(s.updateManager)
}
// MDM policy reload ticker: every minute the desktop daemon re-reads
@@ -190,7 +181,9 @@ func (s *Server) Start() error {
return nil
}
ctx, cancel := context.WithCancel(s.rootCtx)
// actCancel cancels in-flight foreground operations (login/status); the run
// itself is owned by the supervisor and stopped via Stop, not this cancel.
_, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
// copy old default config
@@ -232,99 +225,14 @@ func (s *Server) Start() error {
return nil
}
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
// Boot autoconnect: no incoming RPC metadata. The supervisor runs the
// client and reconnects internally; we just fire and forget (the run owns
// its established/done channels).
s.connectClient.RunAsync(config, nil)
s.publishConfigChangedEvent("startup")
return nil
}
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
//
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
// — placed in the function-scope defer so every return path (panic,
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
// it. Callers that need to observe "is the goroutine still alive?" use
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
// goroutine.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
if giveUpChan != nil {
close(giveUpChan)
}
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
return
}
backOff := getConnectWithBackoff(ctx)
go func() {
t := time.NewTicker(24 * time.Hour)
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState()
if mgmtState.Connected && signalState.Connected {
log.Tracef("resetting status")
backOff.Reset()
} else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
}
}
}
}()
runOperation := func() error {
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
}
log.Tracef("client connection exited gracefully, do not need to retry")
return nil
}
if err := backoff.Retry(runOperation, backOff); err != nil {
log.Errorf("operation failed: %v", err)
}
// giveUpChan is closed by the function-scope defer.
}
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
// still running. Returns false when no goroutine has ever been started
// AND when the most recent one has already closed clientGiveUpChan on
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
// completion, or backoff retry exhaustion).
//
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
// is written by Start/Up under the same lock.
func (s *Server) connectionGoroutineRunning() bool {
if s.clientGiveUpChan == nil {
return false
}
select {
case <-s.clientGiveUpChan:
return false
default:
return true
}
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
@@ -720,13 +628,22 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
// Up starts engine work in the daemon.
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
s.mutex.Lock()
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
// goroutine is still trying. When intent is up AND goroutine is alive,
// the existing engine is on the job — just wait for it. When intent
// is up but the goroutine has given up (backoff exhausted) OR when
// intent is down, fall through to spawn a fresh retry loop.
if s.clientRunning && s.connectionGoroutineRunning() {
// The client (and its supervisor) is built once in New(), so a nil here
// never happens in production — Up is only reachable after New() has run and
// the gRPC server is serving. The real case this guards is the daemon
// SHUTTING DOWN: rootCtx is cancelled, the supervisor is no longer accepting
// commands, so ServiceRunning() is false even though the client exists. Bail
// loud instead of enqueuing a run that will never start. (nil only happens in
// tests that build a Server without New(); ServiceRunning is nil-safe.)
if !s.connectClient.ServiceRunning() {
s.mutex.Unlock()
return nil, fmt.Errorf("service is not running, start the netbird service for 'up' to take effect")
}
// If a connection run is already in flight, the existing engine is on the
// job — just wait for it. Otherwise fall through to start a fresh run.
if s.connectClient.ConnectionRunning() {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
@@ -764,14 +681,14 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
if s.actCancel != nil {
s.actCancel()
}
ctx, cancel := context.WithCancel(s.rootCtx)
md, ok := metadata.FromIncomingContext(callerCtx)
if ok {
ctx = metadata.NewOutgoingContext(ctx, md)
}
// actCancel cancels in-flight foreground ops (login/status); the run is
// owned by the supervisor and stopped via Stop, not this cancel.
_, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
// Forward the caller's gRPC metadata (e.g. UI user-agent) into the run.
md, _ := metadata.FromIncomingContext(callerCtx)
if s.config == nil {
s.mutex.Unlock()
return nil, fmt.Errorf("config is not defined, please call login command first")
@@ -812,35 +729,26 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
s.connectClient.RunAsync(s.config, md)
s.publishConfigChangedEvent("up_rpc")
s.mutex.Unlock()
return s.waitForUp(callerCtx)
}
// todo: handle potential race conditions
// waitForUp blocks until the in-flight run becomes established (success) or ends
// before that (failure). The wait is owned by the supervisor (via the client) —
// the daemon holds no per-run state here.
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel()
select {
case <-s.clientGiveUpChan:
return nil, fmt.Errorf("client gave up to connect")
case <-s.clientRunningChan:
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil
case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready")
return nil, callerCtx.Err()
case <-timeoutCtx.Done():
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
if err := s.connectClient.WaitEstablishedOrDone(timeoutCtx); err != nil {
log.Debugf("waiting for the connection to be established failed: %v", err)
return nil, fmt.Errorf("connection not established: %w", err)
}
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil
}
// resolveProfileHandle resolves a wire-level profile handle (display
@@ -935,11 +843,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
// Down engine work in the daemon.
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
giveUpChan := s.clientGiveUpChan
// cleanupConnection stops the run through the supervisor, which blocks until
// the run has fully unwound — no separate goroutine-quiescence wait needed.
if err := s.cleanupConnection(); err != nil {
s.mutex.Unlock()
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
@@ -948,20 +856,6 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.mutex.Unlock()
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
// The giveUpChan is closed at the end of connectWithRetryRuns.
if giveUpChan != nil {
select {
case <-giveUpChan:
log.Debugf("client goroutine finished successfully")
case <-time.After(5 * time.Second):
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
}
}
return &proto.DownResponse{}, nil
}
@@ -972,38 +866,19 @@ func (s *Server) cleanupConnection() error {
return ErrServiceNotUp
}
// Daemon intent flips to "down" — all callers (Down RPC,
// Logout RPC handlers) tear down the connection because the user
// explicitly asked for it. MDM restart does NOT go through this
// path, so its clientRunning stays true.
s.clientRunning = false
// Capture the engine reference before cancelling the context.
// After actCancel(), the connectWithRetryRuns goroutine wakes up
// and sets connectClient.engine = nil, causing connectClient.Stop()
// to skip the engine shutdown entirely.
var engine *internal.Engine
if s.connectClient != nil {
engine = s.connectClient.Engine()
// Tear the client down through the lifecycle supervisor BEFORE cancelling
// the retry context. Stop serializes on the supervisor queue and blocks
// until the in-flight run has fully unwound (a clean, synchronous teardown).
// It must run before actCancel: cancelling the context first would make
// Stop observe a dead context and return early without waiting.
if err := s.connectClient.Stop(); err != nil {
return err
}
// Stop the retry goroutine so it does not start a fresh run. The client
// itself is daemon-lifetime and intentionally kept (a later Up reuses it).
s.actCancel()
if s.connectClient == nil {
return nil
}
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
// actCancel() lets the run loop stop the engine too, so both stop it
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
// making the run loop the sole owner of engine shutdown.
if engine != nil {
if err := engine.Stop(); err != nil {
return err
}
}
s.connectClient = nil
s.isSessionActive.Store(false)
log.Infof("service is down")
@@ -1138,7 +1013,7 @@ func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfi
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
if err == nil && activeProf.ID == profile.ID && s.connectClient.ConnectionRunning() {
return s.sendLogoutRequest(ctx)
}
@@ -1184,48 +1059,13 @@ func (s *Server) Status(
ctx context.Context,
msg *proto.StatusRequest,
) (*proto.StatusResponse, error) {
s.mutex.Lock()
// Only wait if the retry-loop goroutine is alive and making
// progress. clientRunning=true with connectionGoroutineRunning=false means the
// backoff has given up — there is nothing to wait for; let the
// caller observe the failed status directly.
alive := s.connectionGoroutineRunning()
s.mutex.Unlock()
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
return nil, err
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-s.clientGiveUpChan:
ticker.Stop()
break loop
case <-s.clientRunningChan:
ticker.Stop()
break loop
case <-ticker.C:
status, err := state.Status()
if err != nil {
continue
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
continue
case <-ctx.Done():
return nil, ctx.Err()
}
// A run that hits a terminal auth failure now exits on its own (engine marks
// NeedsLogin), so we no longer poll-and-cancel: we wait for the in-flight run
// to become established or to end. With no run in flight this returns
// immediately (errNoRunInFlight); either way we then report the status below.
if msg.WaitForReady != nil && *msg.WaitForReady {
if err := s.connectClient.WaitEstablishedOrDone(ctx); err != nil && ctx.Err() != nil {
return nil, ctx.Err()
}
}
@@ -1263,10 +1103,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
@@ -1304,10 +1140,6 @@ func (s *Server) GetPeerSSHHostKey(
statusRecorder := s.statusRecorder
s.mutex.Unlock()
if connectClient == nil {
return nil, errors.New("client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, errors.New("engine not started")
@@ -1474,17 +1306,13 @@ func (s *Server) WaitJWTToken(
// ExposeService exposes a local port via the NetBird reverse proxy.
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
s.mutex.Lock()
if !s.clientRunning {
if !s.connectClient.ConnectionRunning() {
s.mutex.Unlock()
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
}
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
@@ -1538,10 +1366,6 @@ func isUnixRunningDesktop() bool {
}
func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
}
engine := s.connectClient.Engine()
if engine == nil {
return
@@ -1820,22 +1644,6 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
log.Tracef("running client connection")
client := internal.NewConnectClient(ctx, config, statusRecorder)
client.SetUpdateManager(s.updateManager)
client.SetSyncResponsePersistence(s.persistSyncResponse)
s.mutex.Lock()
s.connectClient = client
s.mutex.Unlock()
if err := client.Run(runningChan, s.logFile); err != nil {
return err
}
return nil
}
// MDM authority: when the platform-native MDM source sets a kill switch
// key (regardless of true/false value), that value wins. The CLI flag
// supplied at service install time is the fallback used only when the
@@ -1897,45 +1705,6 @@ func (s *Server) onSessionExpire() {
}
}
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {

View File

@@ -15,14 +15,19 @@ import (
)
func newTestServer() *Server {
return &Server{
rootCtx: context.Background(),
ctx := context.Background()
s := &Server{
rootCtx: ctx,
statusRecorder: peer.NewRecorder(""),
}
// Honor the production invariant: the daemon-lifetime client always exists
// (built in New). Server methods rely on s.connectClient being non-nil.
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
return s
}
func newDummyConnectClient(ctx context.Context) *internal.ConnectClient {
return internal.NewConnectClient(ctx, nil, nil)
return internal.NewConnectClient(ctx, nil)
}
// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient
@@ -87,41 +92,36 @@ func TestConcurrentConnectClientAccess(t *testing.T) {
assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic")
}
// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection
// properly nils out connectClient.
func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
// TestCleanupConnection_KeepsClientStopsRunning validates that cleanupConnection
// clears the daemon "up" intent but KEEPS the daemon-lifetime ConnectClient
// (it is reused across Up/Down; only the run is stopped).
func TestCleanupConnection_KeepsClientStopsRunning(t *testing.T) {
s := newTestServer()
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
s.connectClient = newDummyConnectClient(context.Background())
s.clientRunning = true
err := s.cleanupConnection()
require.NoError(t, err)
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
assert.NotNil(t, s.connectClient, "connectClient is daemon-lifetime and must persist after cleanup")
assert.False(t, s.connectClient.ConnectionRunning(), "no run should be in flight after cleanup")
}
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
// when connectClient is nil.
func TestCleanState_NilConnectClient(t *testing.T) {
// TestCleanState_NotConnected validates that CleanState doesn't panic when no
// connection run is in flight.
func TestCleanState_NotConnected(t *testing.T) {
s := newTestServer()
s.connectClient = nil
s.profileManager = nil // will cause error if it tries to proceed past the nil check
s.profileManager = nil // will cause error if it tries to proceed
// Should not panic — the nil check should prevent calling Status() on nil
assert.NotPanics(t, func() {
_, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true})
})
}
// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic
// when connectClient is nil.
func TestDeleteState_NilConnectClient(t *testing.T) {
// TestDeleteState_NotConnected validates that DeleteState doesn't panic when no
// connection run is in flight.
func TestDeleteState_NotConnected(t *testing.T) {
s := newTestServer()
s.connectClient = nil
s.profileManager = nil
assert.NotPanics(t, func() {
@@ -129,60 +129,6 @@ func TestDeleteState_NilConnectClient(t *testing.T) {
})
}
// TestDownThenUp_StaleRunningChan documents the known state issue where
// clientRunningChan from a previous connection is already closed, causing
// waitForUp() to return immediately on reconnect.
func TestDownThenUp_StaleRunningChan(t *testing.T) {
s := newTestServer()
// Simulate state after a successful connection
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
close(s.clientRunningChan) // closed when engine started
s.clientGiveUpChan = make(chan struct{})
s.connectClient = newDummyConnectClient(context.Background())
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
// Simulate Down(): cleanupConnection sets connectClient = nil and
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
// remains independent of intent — derived from clientGiveUpChan.
s.mutex.Lock()
err := s.cleanupConnection()
s.mutex.Unlock()
require.NoError(t, err)
// After cleanup: connectClient is nil, clientRunning is false (intent
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
// (goroutine teardown is independent of the intent flag).
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
s.mutex.Unlock()
// waitForUp() returns immediately due to stale closed clientRunningChan
ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer ctxCancel()
waitDone := make(chan error, 1)
go func() {
_, err := s.waitForUp(ctx)
waitDone <- err
}()
select {
case err := <-waitDone:
assert.NoError(t, err, "waitForUp returns success on stale channel")
// But connectClient is still nil — this is the stale state issue
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success")
s.mutex.Unlock()
case <-time.After(1 * time.Second):
t.Fatal("waitForUp should have returned immediately due to stale closed channel")
}
}
// TestConnectClient_EngineNilOnFreshClient validates that a newly created
// ConnectClient has nil Engine (before Run is called).
func TestConnectClient_EngineNilOnFreshClient(t *testing.T) {

View File

@@ -31,7 +31,6 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/server"
@@ -61,65 +60,6 @@ var (
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
ic := profilemanager.ConfigInput{
ManagementURL: "http://" + mgmtAddr,
ConfigPath: t.TempDir() + "/test-profile.json",
}
config, err := profilemanager.UpdateOrCreateConfig(ic)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
currUser, err := user.Current()
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false, false)
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
func TestServer_Up(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir

View File

@@ -9,7 +9,6 @@ import (
"google.golang.org/grpc/status"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/proto"
@@ -38,7 +37,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro
// CleanState handles cleaning of states (performing cleanup operations)
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
if s.connectClient.ConnectionRunning() {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
@@ -81,7 +80,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (
// DeleteState handles deletion of states without cleanup
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
if s.connectClient.ConnectionRunning() {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}

View File

@@ -62,10 +62,6 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
}
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
if s.connectClient == nil {
return nil, nil, fmt.Errorf("connect client not initialized")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, nil, fmt.Errorf("engine not initialized")

View File

@@ -3,7 +3,6 @@ package system
import (
"context"
"net/netip"
"slices"
"strings"
log "github.com/sirupsen/logrus"
@@ -122,23 +121,6 @@ func (i *Info) SetFlags(
}
}
// removeAddresses drops network addresses whose IP matches any of the given
// addresses, regardless of prefix length. Used to exclude the NetBird overlay
// address, which otherwise churns the meta as the interface comes and goes.
func (i *Info) removeAddresses(ips ...netip.Addr) {
if len(ips) == 0 {
return
}
filtered := i.NetworkAddresses[:0]
for _, addr := range i.NetworkAddresses {
if slices.Contains(ips, addr.NetIP.Addr()) {
continue
}
filtered = append(filtered, addr)
}
i.NetworkAddresses = filtered
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)
@@ -165,9 +147,7 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
// excludeIPs are dropped from the reported network addresses (e.g. our own
// WireGuard overlay address, which otherwise churns the peer meta).
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, error) {
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))
processCheckPaths := make([]string, 0)
for _, check := range checks {
@@ -182,7 +162,6 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
info := GetInfo(ctx)
info.Files = files
info.removeAddresses(excludeIPs...)
log.Debugf("all system information gathered successfully")
return info, nil

View File

@@ -2,7 +2,6 @@ package system
import (
"context"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
@@ -44,42 +43,3 @@ func Test_NetAddresses(t *testing.T) {
t.Errorf("no network addresses found")
}
}
func TestInfo_RemoveAddresses(t *testing.T) {
addr := func(cidr string) NetworkAddress {
return NetworkAddress{NetIP: netip.MustParsePrefix(cidr)}
}
info := &Info{
NetworkAddresses: []NetworkAddress{
addr("192.168.1.7/24"),
addr("100.76.70.97/32"), // overlay v4 (host mask /32)
addr("2001:818:c51b:4800:845:a65d:ae6f:623f/64"), // real global v6
addr("fd00:1234::1/64"), // overlay v6
},
}
// Overlay addresses as the engine knows them, with a different mask (/16, /64).
info.removeAddresses(
netip.MustParseAddr("100.76.70.97"),
netip.MustParseAddr("fd00:1234::1"),
)
want := []string{"192.168.1.7/24", "2001:818:c51b:4800:845:a65d:ae6f:623f/64"}
if len(info.NetworkAddresses) != len(want) {
t.Fatalf("got %d addresses, want %d: %v", len(info.NetworkAddresses), len(want), info.NetworkAddresses)
}
for i, w := range want {
if got := info.NetworkAddresses[i].NetIP.String(); got != w {
t.Errorf("address[%d] = %s, want %s", i, got, w)
}
}
}
func TestInfo_RemoveAddresses_NoOp(t *testing.T) {
info := &Info{NetworkAddresses: []NetworkAddress{{NetIP: netip.MustParsePrefix("10.0.0.1/24")}}}
info.removeAddresses()
if len(info.NetworkAddresses) != 1 {
t.Errorf("expected no change with empty input, got %v", info.NetworkAddresses)
}
}

View File

@@ -46,9 +46,7 @@ func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
if !ok {
return NetworkAddress{}, false
}
// Skip link-local and multicast: they carry no routable peer info and the
// IPv6 link-local of a flapping NIC churns the meta on every up/down.
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
if ipNet.IP.IsLoopback() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())

View File

@@ -1,45 +0,0 @@
//go:build !ios
package system
import (
"net"
"testing"
)
func mustIPNet(t *testing.T, cidr string) *net.IPNet {
t.Helper()
ip, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
t.Fatalf("parse %q: %v", cidr, err)
}
ipNet.IP = ip
return ipNet
}
func TestToNetworkAddress_Filtering(t *testing.T) {
const mac = "c8:4b:d6:b6:04:ac"
tests := []struct {
name string
cidr string
want bool
}{
{"ipv4 global", "10.65.16.181/23", true},
{"ipv6 global", "2620:52:0:4110:102d:6a98:ee75:8b92/64", true},
{"ipv4 loopback", "127.0.0.1/8", false},
{"ipv6 loopback", "::1/128", false},
{"ipv6 link-local", "fe80::871:4c25:23d7:2529/64", false},
{"ipv4 link-local", "169.254.1.2/16", false},
{"ipv6 multicast", "ff02::1/128", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, got := toNetworkAddress(mustIPNet(t, tt.cidr), mac)
if got != tt.want {
t.Errorf("toNetworkAddress(%s) ok = %v, want %v", tt.cidr, got, tt.want)
}
})
}
}

View File

@@ -418,14 +418,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
case args.showProfiles:
s.showProfilesUI()
case args.showQuickActions:
// Suppress the on-boot Quick Actions popup when the daemon
// reports DisableAutoConnect=true — that flag carries both the
// user's "Connect on Startup = off" preference AND any MDM-
// enforced override (applyMDMPolicy writes the policy value
// into the same Config field). See netbirdio/netbird#5744.
if !s.disableAutoConnectFromDaemon() {
s.showQuickActionsUI()
}
s.showQuickActionsUI()
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
}
@@ -1345,40 +1338,6 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
return features, nil
}
// disableAutoConnectFromDaemon returns true when the daemon reports
// the active profile has DisableAutoConnect=true. Used by the
// --quick-actions startup path to suppress the on-boot popup when the
// user (or an MDM admin) opted out of auto-connecting; both cases
// converge on the same Config field because applyMDMPolicy writes the
// policy value into it. Returns false on any RPC / lookup failure so a
// daemon hiccup does not silently swallow the popup.
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
activeProf, err := s.profileManager.GetActiveProfile()
if err != nil {
log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
return false
}
currUser, err := user.Current()
if err != nil {
log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
return false
}
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
return false
}
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
})
if err != nil {
log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
return false
}
return srvCfg.GetDisableAutoConnect()
}
// getSrvConfig from the service to show it in the settings window.
func (s *serviceClient) getSrvConfig() {
s.managementURL = profilemanager.DefaultManagementURL

View File

@@ -0,0 +1,56 @@
# Build environments
Dockerfiles that pin the same toolchain CI uses, so a developer can
reproduce a CI build locally without installing platform SDKs on their
workstation. The version pins in each `Dockerfile` must stay in lockstep
with `.github/workflows/`.
## `android/`
Mirrors `.github/workflows/mobile-build-validation.yml` (`android_build`
job). Carries Go 1.25.5, Adopt JDK 11, Android cmdline-tools 8512546,
NDK 23.1.7779620 and gomobile pinned at the CI commit. Use it to
produce `netbird.aar` from `./client/android`:
```bash
docker build -t netbird/build-android docker/build-env/android
docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
gomobile bind \
-o netbird.aar \
-javapkg=io.netbird.gomobile \
-ldflags="-checklinkname=0 \
-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
-X github.com/netbirdio/netbird/version.version=local" \
./client/android
```
To build the full Android APK, bind-mount the `android-client` repo as
well and run its own `./gradlew assembleDebug` from inside the
container (the gradle wrapper ships with `android-client`).
## `windows-cross/`
Cross-compiles Windows binaries from Linux using `mingw-w64`. Lets you
verify that `GOOS=windows go build ./...` compiles cleanly without
needing a Windows VM. Cannot run Windows tests — the `golang-test-windows`
CI job executes on a native `windows-latest` runner with wintun.dll
and PsExec, neither of which lives under Linux containers.
```bash
docker build -t netbird/build-windows docker/build-env/windows-cross
docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
```
## What is NOT here
- **iOS / macOS**: cannot legally run macOS in Docker (Apple EULA),
and Xcode is not redistributable. The `ios_build` CI job uses a
`macos-latest` GitHub runner; locally you need a real Mac.
- **Native Windows tests**: see note above. The Linux+mingw image
builds, it does not execute Windows-host code paths
(registry, wintun, services, PsExec workflows).
When CI version pins change, update the corresponding `ARG` lines in
the Dockerfiles and the README's table of versions.

View File

@@ -0,0 +1,86 @@
# Android build environment.
#
# Mirrors the toolchain pinned by .github/workflows/mobile-build-validation.yml
# so a `gomobile bind` against ./client/android in this image produces the
# same netbird.aar that CI builds.
#
# Tooling versions (must stay in sync with the CI workflow):
# - Ubuntu 22.04 (matches the ubuntu-latest GitHub runner)
# - Go 1.25.5 (matches go.mod)
# - Adopt JDK 11 (matches actions/setup-java@v3 java-version: 11, distribution: adopt)
# - Android SDK cmdline-tools 8512546
# - Android NDK 23.1.7779620
# - gomobile commit v0.0.0-20251113184115-a159579294ab
#
# Usage (from the netbird repo root):
#
# docker build -t netbird/build-android docker/build-env/android
#
# # bind the netbird checkout in and run the same gomobile command CI runs
# docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
# gomobile bind \
# -o netbird.aar \
# -javapkg=io.netbird.gomobile \
# -ldflags="-checklinkname=0 \
# -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
# -X github.com/netbirdio/netbird/version.version=local" \
# ./client/android
#
# To build the full APK, mount the android-client repo too and run
# `./gradlew assembleDebug` from /android-client (this image carries
# gradle's prerequisites JDK + Android SDK but not the gradle wrapper —
# that ships with android-client).
FROM ubuntu:22.04
ARG DEBIAN_FRONTEND=noninteractive
# Versions — bump in lockstep with .github/workflows/mobile-build-validation.yml.
ARG GO_VERSION=1.25.5
ARG ANDROID_CMDLINE_TOOLS_VERSION=8512546
ARG ANDROID_NDK_VERSION=23.1.7779620
ARG GOMOBILE_VERSION=v0.0.0-20251113184115-a159579294ab
ENV ANDROID_HOME=/opt/android-sdk
ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${ANDROID_NDK_VERSION}
ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
ENV GOPATH=/go
ENV GOTOOLCHAIN=local
ENV CGO_ENABLED=0
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${ANDROID_HOME}/cmdline-tools/latest/bin:${ANDROID_HOME}/platform-tools:${JAVA_HOME}/bin:${PATH}
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
curl \
unzip \
git \
openjdk-11-jdk-headless \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Install Go (matches go.mod). actions/setup-go fetches the same tarball.
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
| tar -C /usr/local -xz \
&& go version
# Install Android SDK command-line tools, accept licenses, install NDK.
RUN mkdir -p "${ANDROID_HOME}/cmdline-tools" \
&& curl -fsSL -o /tmp/cmdline.zip \
"https://dl.google.com/android/repository/commandlinetools-linux-${ANDROID_CMDLINE_TOOLS_VERSION}_latest.zip" \
&& unzip -q /tmp/cmdline.zip -d "${ANDROID_HOME}/cmdline-tools" \
&& mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" \
&& rm /tmp/cmdline.zip \
&& yes | sdkmanager --licenses > /dev/null \
&& sdkmanager --install "ndk;${ANDROID_NDK_VERSION}" "platform-tools" > /dev/null
# Install gomobile at the same commit CI pins. Don't run `gomobile init` here:
# `init` resolves the NDK at runtime, do it on the first bind in the mounted
# workspace so the cache lands on the host volume.
RUN GOBIN=/usr/local/bin go install "golang.org/x/mobile/cmd/gomobile@${GOMOBILE_VERSION}" \
&& gomobile version
WORKDIR /src
# Default entrypoint is a plain shell so the image is composable: callers pass
# the full gomobile / gradle command they want to run.
CMD ["/bin/bash"]

View File

@@ -0,0 +1,63 @@
# Windows-cross build environment.
#
# Cross-compiles Windows .exe targets from a Linux container using
# mingw-w64. Mirrors the toolchain set used by
# .github/workflows/golang-test-windows.yml insofar as that is possible
# without a Windows kernel.
#
# IMPORTANT — what this image CAN do:
# - `GOOS=windows go build ./...` to validate that Windows builds compile
# - CGO Windows cross-compile via x86_64-w64-mingw32-gcc when CGO_ENABLED=1
# (matches CI's choco-installed mingw-w64)
#
# IMPORTANT — what this image CANNOT do:
# - Run Windows binaries (no Windows kernel under Docker on Linux).
# - Replicate the CI's `go test` runs which execute on a real
# windows-latest runner (wintun.dll, PsExec, registry, etc.).
# Use the CI for that or a native Windows VM.
#
# Usage (from the netbird repo root):
#
# docker build -t netbird/build-windows docker/build-env/windows-cross
#
# # Cross-compile a static client (.exe) from Linux:
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
# bash -c 'CGO_ENABLED=1 GOOS=windows GOARCH=amd64 \
# CC=x86_64-w64-mingw32-gcc CXX=x86_64-w64-mingw32-g++ \
# go build -o netbird.exe ./client'
#
# # Just validate that everything *compiles* on Windows (no CGO):
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
# bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
#
# Tooling versions (keep in sync with go.mod and any future explicit pin
# documented in golang-test-windows.yml):
# - Ubuntu 22.04
# - Go 1.25.5 (matches go.mod)
# - mingw-w64 (Ubuntu package — pin further if drift becomes a problem)
FROM ubuntu:22.04
ARG DEBIAN_FRONTEND=noninteractive
ARG GO_VERSION=1.25.5
ENV GOPATH=/go
ENV GOTOOLCHAIN=local
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${PATH}
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
curl \
git \
build-essential \
mingw-w64 \
&& rm -rf /var/lib/apt/lists/*
# Install Go (matches go.mod).
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
| tar -C /usr/local -xz \
&& go version
WORKDIR /src
CMD ["/bin/bash"]

2
go.mod
View File

@@ -341,7 +341,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

4
go.sum
View File

@@ -510,8 +510,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a h1:3CWK+yTvRKOcC0Q8VCTGy4l60TEb27CQVS7LkMxwjmw=
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=

View File

@@ -1,616 +0,0 @@
#!/bin/bash
set -e
set -o pipefail
# NetBird Enterprise — Getting Started
# Single-node bootstrap for a self-hosted NetBird Enterprise stack with the
# embedded identity provider. Owner is created via first-login flow.
SED_STRIP_PADDING='s/=//g'
check_docker_compose() {
if command -v docker-compose &> /dev/null; then
echo "docker-compose"
return
fi
if docker compose --help &> /dev/null; then
echo "docker compose"
return
fi
echo "docker-compose is not installed or not in PATH. See https://docs.docker.com/engine/install/" > /dev/stderr
exit 1
}
check_openssl() {
if ! command -v openssl &> /dev/null; then
echo "openssl is not installed or not in PATH." > /dev/stderr
exit 1
fi
}
rand_secret() {
openssl rand -base64 32 | sed "$SED_STRIP_PADDING"
}
rand_b64_key() {
openssl rand -base64 32
}
check_nb_domain() {
local domain="$1"
if [[ -z "$domain" ]]; then
echo "The domain cannot be empty." > /dev/stderr
return 1
fi
if [[ "$domain" == "netbird.example.com" ]]; then
echo "The domain cannot be netbird.example.com" > /dev/stderr
return 1
fi
if [[ "$domain" =~ ^[0-9.]+$ ]]; then
echo "An IP address is not allowed. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
return 1
fi
if [[ ! "$domain" =~ ^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$ ]]; then
echo "The value '$domain' is not a valid FQDN. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
return 1
fi
return 0
}
check_domain_resolves() {
local domain="$1"
if command -v getent &> /dev/null && getent hosts "$domain" &> /dev/null; then return 0; fi
if command -v host &> /dev/null && host "$domain" &> /dev/null; then return 0; fi
if command -v dig &> /dev/null && [[ -n "$(dig +short "$domain" 2>/dev/null)" ]]; then return 0; fi
if command -v nslookup &> /dev/null && nslookup "$domain" &> /dev/null; then return 0; fi
return 1
}
read_nb_domain() {
local value=""
echo -n "Enter the FQDN for NetBird (must resolve via DNS, e.g. netbird.my-domain.com): " > /dev/stderr
read -r value < /dev/tty
if ! check_nb_domain "$value"; then
read_nb_domain
return
fi
if ! check_domain_resolves "$value"; then
echo "" > /dev/stderr
echo "Warning: '$value' does not resolve via DNS from this host." > /dev/stderr
echo "Caddy will not be able to issue TLS certificates until it does." > /dev/stderr
local confirm=""
echo -n "Continue anyway? [y/N]: " > /dev/stderr
read -r confirm < /dev/tty
if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
read_nb_domain
return
fi
fi
echo "$value"
}
read_required() {
local prompt="$1"
local value=""
while [[ -z "$value" ]]; do
echo -n "$prompt: " > /dev/stderr
read -r value < /dev/tty
if [[ -z "$value" ]]; then
echo "Value cannot be empty." > /dev/stderr
fi
done
echo "$value"
}
read_secret() {
local prompt="$1"
local value=""
while [[ -z "$value" ]]; do
echo -n "$prompt: " > /dev/stderr
read -rs value < /dev/tty
echo "" > /dev/stderr
if [[ -z "$value" ]]; then
echo "Value cannot be empty." > /dev/stderr
fi
done
echo "$value"
}
# read_yes_no "<prompt>" [<default y|n>]
read_yes_no() {
local prompt="$1"
local default="${2:-n}"
local hint
if [[ "$default" == "y" ]]; then
hint="[Y/n]"
else
hint="[y/N]"
fi
echo -n "${prompt} ${hint}: " > /dev/stderr
local ans=""
read -r ans < /dev/tty
if [[ -z "$ans" ]]; then
ans="$default"
fi
case "$ans" in
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
*) echo "no" ;;
esac
}
wait_postgres() {
set +e
echo -n "Waiting for postgres to become ready"
local counter=1
while true; do
if $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" &> /dev/null; then
break
fi
if [[ $counter -eq 60 ]]; then
echo ""
echo "Postgres is taking too long. Recent logs:"
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
exit 1
fi
echo -n " ."
sleep 2
counter=$((counter + 1))
done
echo " done"
set -e
}
init_environment() {
check_openssl
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
if [[ -f .env ]] || [[ -f docker-compose.yml ]] || [[ -f config.yaml ]] || [[ -f Caddyfile ]]; then
echo "Generated files already exist in $(pwd)."
echo "If you want to reinitialize the environment, please remove them first:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # removes all containers and volumes"
echo " rm -f .env docker-compose.yml Caddyfile config.yaml"
echo "Be aware this will remove all data from the database."
exit 1
fi
echo "NetBird Enterprise bootstrap"
echo ""
echo "Traffic flow:"
echo " Enables traffic events logging on the management server."
echo " When enabled, the NetBird stack also runs NATS along with two"
echo " additional containers: netbird-receiver (the traffic log receiver"
echo " service) and netbird-enricher (the traffic log enricher service)."
echo " It still has to be turned on from the dashboard settings afterwards."
echo " See https://docs.netbird.io/manage/activity/traffic-events-logging"
NETBIRD_TRAFFIC_FLOW=$(read_yes_no "Enable traffic flow" "n")
echo ""
NETBIRD_DOMAIN=$(read_nb_domain)
echo ""
NETBIRD_LICENSE_KEY=$(read_secret "Enter license key (input hidden)")
GHCR_USERNAME="netbirdExtAccess1"
GHCR_TOKEN=$(read_secret "Enter GHCR token (input hidden)")
POSTGRES_USER="netbird"
POSTGRES_DB="netbird"
POSTGRES_PASSWORD=$(rand_secret)
NETBIRD_ENCRYPTION_KEY=$(rand_b64_key)
NETBIRD_RELAY_AUTH_SECRET=$(rand_secret)
POSTGRES_DSN="host=postgres user=${POSTGRES_USER} password=${POSTGRES_PASSWORD} dbname=${POSTGRES_DB} port=5432 sslmode=disable TimeZone=UTC"
NETBIRD_RELAY_ENDPOINT="rels://${NETBIRD_DOMAIN}:443"
echo ""
echo "Selected:"
echo " Traffic flow: ${NETBIRD_TRAFFIC_FLOW}"
echo " Domain: ${NETBIRD_DOMAIN}"
echo ""
echo "Rendering files into $(pwd) ..."
install -m 600 /dev/null .env
render_env >> .env
render_docker_compose > docker-compose.yml
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' docker-compose.yml && rm -f docker-compose.yml.bak
fi
render_caddyfile > Caddyfile
install -m 600 /dev/null config.yaml
render_config_yaml >> config.yaml
echo "Logging in to ghcr.io ..."
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
unset GHCR_TOKEN
echo ""
echo "Pulling images ..."
$DOCKER_COMPOSE_COMMAND pull
echo ""
echo "Starting postgres ..."
$DOCKER_COMPOSE_COMMAND up -d postgres
sleep 2
wait_postgres
echo ""
echo "Starting remaining services ..."
$DOCKER_COMPOSE_COMMAND up -d
echo ""
echo "Done."
echo ""
echo "Dashboard: https://${NETBIRD_DOMAIN}"
echo ""
echo "Open the dashboard in a browser to complete the first-login owner setup."
echo "All configuration and secrets are stored (mode 600) in $(pwd)/.env"
echo ""
echo "Tail logs:"
echo " cd $(pwd) && $DOCKER_COMPOSE_COMMAND logs -f netbird-server caddy"
}
# ------------------------------------------------------------------
# Renderers
# ------------------------------------------------------------------
render_env() {
cat <<EOF
# Generated by getting-started-enterprise.sh
# Holds all configuration and secrets for the stack. Mode 600.
# Features (set by the script; don't edit without re-running)
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
# Domain
NETBIRD_DOMAIN=${NETBIRD_DOMAIN}
# Image tags. Default to "latest"
NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-latest}
NETBIRD_SERVER_TAG=${NETBIRD_SERVER_TAG:-latest}
EOF
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
cat <<EOF
NETBIRD_ENRICHER_TAG=${NETBIRD_ENRICHER_TAG:-latest}
NETBIRD_RECEIVER_TAG=${NETBIRD_RECEIVER_TAG:-latest}
EOF
fi
cat <<EOF
# License keys
EOF
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
cat <<EOF
NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
EOF
fi
cat <<EOF
NETBIRD_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
EOF
cat <<EOF
# Postgres
POSTGRES_USER=${POSTGRES_USER}
POSTGRES_DB=${POSTGRES_DB}
POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
NETBIRD_STORE_ENGINE_POSTGRES_DSN=${POSTGRES_DSN}
# Relay
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT}
NETBIRD_RELAY_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
# Datastore encryption
NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}
# Dashboard OIDC scopes
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-openid profile email groups}
EOF
}
render_docker_compose() {
render_compose_header
render_compose_common
render_compose_server
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
render_compose_flow
fi
render_compose_postgres
render_compose_footer
}
render_compose_header() {
cat <<'EOF'
x-default: &default
restart: unless-stopped
logging:
driver: json-file
options:
max-size: '500m'
max-file: '2'
services:
EOF
}
render_compose_common() {
cat <<'EOF'
caddy:
<<: *default
image: caddy:2
container_name: netbird-caddy
networks: [netbird]
environment:
- CADDY_SECURE_DOMAIN=${NETBIRD_DOMAIN}
ports:
- '443:443'
- '443:443/udp'
- '80:80'
volumes:
- netbird_caddy_data:/data
- ./Caddyfile:/etc/caddy/Caddyfile
dashboard:
<<: *default
image: ghcr.io/netbirdio/dashboard-cloud:${NETBIRD_DASHBOARD_TAG}
container_name: netbird-dashboard
networks: [netbird]
environment:
- NETBIRD_MGMT_API_ENDPOINT=https://${NETBIRD_DOMAIN}
- NETBIRD_MGMT_GRPC_API_ENDPOINT=https://${NETBIRD_DOMAIN}
- AUTH_AUDIENCE=netbird-dashboard
- AUTH_CLIENT_ID=netbird-dashboard
- AUTH_CLIENT_SECRET=
- AUTH_AUTHORITY=https://${NETBIRD_DOMAIN}/oauth2
- USE_AUTH0=false
- AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES}
- AUTH_REDIRECT_URI=/nb-auth
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
- NETBIRD_TOKEN_SOURCE=accessToken
- NGINX_SSL_PORT=443
- LETSENCRYPT_DOMAIN=
- LETSENCRYPT_EMAIL=
EOF
}
render_compose_server() {
cat <<'EOF'
netbird-server:
<<: *default
image: ghcr.io/netbirdio/netbird-server-cloud:${NETBIRD_SERVER_TAG}
container_name: netbird-server
networks: [netbird]
depends_on:
dashboard:
condition: service_started
postgres:
condition: service_healthy
ports:
- '3478:3478/udp'
volumes:
- netbird_data:/var/lib/netbird
- ./config.yaml:/etc/netbird/config.yaml
command: ["--config", "/etc/netbird/config.yaml"]
environment:
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
EOF
}
render_compose_flow() {
cat <<'EOF'
nats:
<<: *default
image: nats:2
container_name: netbird-nats
networks: [netbird]
volumes:
- netbird_nats_data:/data
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
enricher:
<<: *default
image: ghcr.io/netbirdio/flow-enricher-cloud:${NETBIRD_ENRICHER_TAG}
container_name: netbird-enricher
networks: [netbird]
depends_on:
postgres:
condition: service_healthy
nats:
condition: service_started
volumes:
- netbird_enricher:/var/lib/netbird
environment:
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
- NB_DATADIR=/var/lib/netbird
- NB_MANAGEMENT_STORE_ENGINE=postgres
- NB_MANAGEMENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
- NB_TRAFFIC_EVENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
- NB_TRAFFIC_EVENT_STORE_ENGINE=postgres
- NB_MANAGEMENT_STORE_KEY=${NETBIRD_ENCRYPTION_KEY}
- NB_FLOW_ADAPTER_TYPE=nats
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
- NB_FLOW_NATS_STREAM=traffic-events
- NB_METRICS_PORT=9091
- NB_PERSISTENCE_RETENTION_PERIOD=168h
receiver:
<<: *default
image: ghcr.io/netbirdio/flow-receiver-cloud:${NETBIRD_RECEIVER_TAG}
container_name: netbird-receiver
networks: [netbird]
depends_on:
nats:
condition: service_started
environment:
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
- NB_FLOW_LISTEN_PORT=80
- NB_FLOW_ADAPTER_TYPE=nats
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
- NB_FLOW_NATS_STREAM=traffic-events
- NB_FLOW_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
EOF
}
render_compose_postgres() {
cat <<'EOF'
postgres:
<<: *default
image: postgres:17
container_name: netbird-postgres
networks: [netbird]
environment:
- POSTGRES_USER=${POSTGRES_USER}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
- POSTGRES_DB=${POSTGRES_DB}
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
interval: 10s
timeout: 5s
retries: 10
volumes:
- netbird_postgres:/var/lib/postgresql/data
EOF
}
render_compose_footer() {
cat <<'EOF'
volumes:
netbird_data:
EOF
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
cat <<'EOF'
netbird_nats_data:
netbird_enricher:
EOF
fi
cat <<'EOF'
netbird_postgres:
netbird_caddy_data:
networks:
netbird:
EOF
}
render_caddyfile() {
cat <<'EOF'
{
servers :80,:443 {
protocols h1 h2c h2 h3
}
}
(security_headers) {
header * {
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
X-Content-Type-Options "nosniff"
X-Frame-Options "SAMEORIGIN"
X-XSS-Protection "1; mode=block"
-Server
Referrer-Policy strict-origin-when-cross-origin
}
}
:80 {
redir https://{$CADDY_SECURE_DOMAIN}{uri} permanent
}
{$CADDY_SECURE_DOMAIN}:443 {
import security_headers
# Signal (gRPC over h2c)
reverse_proxy /signalexchange.SignalExchange/* h2c://netbird-server:80
# Management (gRPC over h2c + HTTP)
reverse_proxy /management.ManagementService/* h2c://netbird-server:80
reverse_proxy /api/* netbird-server:80
reverse_proxy /ws-proxy/* netbird-server:80
# Embedded IdP (OAuth2 endpoints served by netbird server)
reverse_proxy /oauth2/* netbird-server:80
# Relay (WebSocket multiplexed on the same port)
reverse_proxy /relay* netbird-server:80
EOF
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
cat <<'EOF'
# Flow receiver (gRPC over h2c)
reverse_proxy /flow.FlowService/* h2c://receiver:80
EOF
fi
cat <<'EOF'
# Dashboard
reverse_proxy /* dashboard:80
}
EOF
}
render_config_yaml() {
cat <<EOF
# NetBird Enterprise server configuration.
# Generated by getting-started-enterprise.sh. Mode 600.
server:
listenAddress: ":80"
exposedAddress: "https://${NETBIRD_DOMAIN}:443"
metricsPort: 9090
healthcheckAddress: ":9000"
logLevel: "info"
logFile: "console"
# TLS is terminated by Caddy in front; leave this block empty.
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
authSecret: "${NETBIRD_RELAY_AUTH_SECRET}"
dataDir: "/var/lib/netbird/"
disableAnonymousMetrics: false
disableGeoliteUpdate: false
auth:
issuer: "https://${NETBIRD_DOMAIN}/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
dashboardRedirectURIs:
- "https://${NETBIRD_DOMAIN}/nb-auth"
- "https://${NETBIRD_DOMAIN}/nb-silent-auth"
cliRedirectURIs:
- "http://localhost:53000/"
store:
engine: "postgres"
dsn: "${POSTGRES_DSN}"
encryptionKey: "${NETBIRD_ENCRYPTION_KEY}"
activityStore:
engine: "postgres"
dsn: "${POSTGRES_DSN}"
EOF
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
cat <<EOF
trafficFlow:
enabled: true
address: "https://${NETBIRD_DOMAIN}:443"
interval: "60s"
EOF
fi
}
init_environment

View File

@@ -351,11 +351,6 @@ initialize_default_values() {
NETBIRD_STUN_PORT=3478
# Docker images
# Record whether the operator explicitly pinned the server/proxy images via
# env vars, so the agent-network preset can pick its own defaults without
# clobbering an explicit override.
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
# Combined server replaces separate signal, relay, and management containers
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
@@ -403,53 +398,7 @@ configure_domain() {
return 0
}
apply_agent_network_preset() {
# Agent-network turnkey install: built-in Traefik + NetBird Proxy with
# NB_PROXY_PRIVATE=true, dashboard locked to agent-network-only mode.
# Bypasses every reverse-proxy / proxy / CrowdSec prompt. The only
# inputs we still need from the operator are the domain (handled by
# configure_domain via NETBIRD_DOMAIN env var or interactive prompt)
# and the ACME email — both honor env vars first and fall back to a
# prompt only when unset. CrowdSec is intentionally off.
REVERSE_PROXY_TYPE="0"
ENABLE_PROXY="true"
ENABLE_CROWDSEC="false"
# Agent-network ships dedicated server/proxy images. Honor an explicit
# env override; otherwise pin the agent-network builds.
if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2"
fi
if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2"
fi
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
else
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
fi
echo "" > /dev/stderr
echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):" > /dev/stderr
echo " - reverse proxy: built-in Traefik" > /dev/stderr
echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true" > /dev/stderr
echo " - server image: ${NETBIRD_SERVER_IMAGE}" > /dev/stderr
echo " - proxy image: ${NETBIRD_PROXY_IMAGE}" > /dev/stderr
echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true" > /dev/stderr
echo " - CrowdSec: disabled" > /dev/stderr
echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}" > /dev/stderr
echo "" > /dev/stderr
}
configure_reverse_proxy() {
# Short-circuit: agent-network preset locks every reverse-proxy /
# proxy / CrowdSec choice and bypasses the interactive prompts.
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
apply_agent_network_preset
return 0
fi
# Prompt for reverse proxy type
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
@@ -961,15 +910,6 @@ NGINX_SSL_PORT=443
# Letsencrypt
LETSENCRYPT_DOMAIN=none
EOF
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
cat <<EOF
# Agent-network preset: dashboard hides the standard NetBird surfaces
# and exposes only the AI Observability + agent-network configuration
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
NETBIRD_AGENT_NETWORK_ONLY=true
EOF
fi
return 0
}
@@ -1006,17 +946,6 @@ NB_PROXY_PROXY_PROTOCOL=true
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
EOF
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
cat <<EOF
# Agent-network preset: turn the proxy into the private reverse-proxy
# ingress for agent-network synth services. Disables the public-facing
# surface so the proxy serves only synth-generated routes (the
# llm_router-driven LLM endpoints) and the per-account inbound
# listeners on the embedded netstack.
NB_PROXY_PRIVATE=true
EOF
fi
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
cat <<EOF
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
@@ -1397,20 +1326,12 @@ print_builtin_traefik_instructions() {
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
fi
echo ""
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
echo "For enterprise environments requiring high availability and advanced integrations,"
echo "consider a commercial on-prem license:"
echo ""
echo " Commercial license: https://netbird.ai/pricing"
echo " Documentation: https://docs.netbird.io/agent-network"
else
echo "This setup is ideal for homelabs and smaller organization deployments."
echo "For enterprise environments requiring high availability and advanced integrations,"
echo "consider a commercial on-prem license or scaling your open source deployment:"
echo ""
echo " Commercial license: https://netbird.io/pricing#on-prem"
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
fi
echo "This setup is ideal for homelabs and smaller organization deployments."
echo "For enterprise environments requiring high availability and advanced integrations,"
echo "consider a commercial on-prem license or scaling your open source deployment:"
echo ""
echo " Commercial license: https://netbird.io/pricing#on-prem"
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
echo ""
if [[ "$ENABLE_PROXY" == "true" ]]; then
echo "NetBird Proxy:"
@@ -1433,11 +1354,6 @@ print_builtin_traefik_instructions() {
echo ""
fi
fi
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
echo "Note: The public domain is only for setting up secure connections."
echo "Your APIs and agent services remain private and are never exposed publicly."
echo ""
fi
return 0
}

View File

@@ -1,638 +0,0 @@
#!/bin/bash
set -e
set -o pipefail
# NetBird — community combined → Enterprise combined migration
#
# Non-destructive migration: produces docker-compose.override.yml (auto-loaded
# by docker compose) and config.yaml.enterprise alongside the operator's
# existing files. Original docker-compose.yml and config.yaml are never
# modified.
#
# Steps (all optional, asked interactively):
# 1. Image swap — replace community images with enterprise cloud images.
# 2. Postgres migration — add Postgres, migrate SQLite data via migrate-store.
# 3. Traffic flow — add NATS + flow-enricher + flow-receiver.
#
# To revert:
# docker compose down
# rm -f docker-compose.override.yml config.yaml.enterprise
# # If Postgres migration was done, also restore the SQLite backup printed
# # at the end of this script's run.
# docker compose up -d
OVERRIDE_FILE="docker-compose.override.yml"
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
check_docker_compose() {
if command -v docker-compose &> /dev/null; then
echo "docker-compose"
return
fi
if docker compose --help &> /dev/null; then
echo "docker compose"
return
fi
echo "docker-compose is not installed or not in PATH." > /dev/stderr
exit 1
}
check_yq() {
if ! command -v yq &> /dev/null; then
cat > /dev/stderr <<'EOF'
yq is required to parse and update YAML safely.
macOS: brew install yq
Linux: https://github.com/mikefarah/yq/releases (download binary into PATH)
Debian: apt-get install yq (Note: must be the mikefarah Go yq, not the Python wrapper.)
EOF
exit 1
fi
if ! yq --version 2>&1 | grep -q "mikefarah"; then
echo "yq is present but appears to be the wrong implementation. The mikefarah Go-based yq is required (https://github.com/mikefarah/yq)." > /dev/stderr
exit 1
fi
}
check_openssl() {
if ! command -v openssl &> /dev/null; then
echo "openssl is not installed or not in PATH." > /dev/stderr
exit 1
fi
}
rand_password() {
openssl rand -hex 32
}
read_required() {
local prompt="$1"
local value=""
while [[ -z "$value" ]]; do
echo -n "$prompt: " > /dev/stderr
read -r value < /dev/tty
if [[ -z "$value" ]]; then
echo "Value cannot be empty." > /dev/stderr
fi
done
echo "$value"
}
read_secret() {
local prompt="$1"
local value=""
while [[ -z "$value" ]]; do
echo -n "$prompt: " > /dev/stderr
read -rs value < /dev/tty
echo "" > /dev/stderr
if [[ -z "$value" ]]; then
echo "Value cannot be empty." > /dev/stderr
fi
done
echo "$value"
}
read_yes_no() {
local prompt="$1"
local default="${2:-n}"
local hint
if [[ "$default" == "y" ]]; then
hint="[Y/n]"
else
hint="[y/N]"
fi
echo -n "${prompt} ${hint}: " > /dev/stderr
local ans=""
read -r ans < /dev/tty
if [[ -z "$ans" ]]; then
ans="$default"
fi
case "$ans" in
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
*) echo "no" ;;
esac
}
# ---------------------------------------------------------------------------
# Detection — read the operator's existing compose to find service names and
# paths we need to override. Bail loudly if shape isn't recognised.
# ---------------------------------------------------------------------------
detect_combined_service() {
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/netbird-server"))) | .[0].key // ""' "$COMPOSE_FILE"
}
detect_dashboard_service() {
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/dashboard"))) | .[0].key // ""' "$COMPOSE_FILE"
}
detect_config_yaml_host_path() {
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/etc/netbird/config.yaml\")) | sub(\":/etc/netbird/config.yaml.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
}
detect_data_volume() {
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/var/lib/netbird\")) | sub(\":/var/lib/netbird.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
}
detect_exposed_address() {
yq eval '.server.exposedAddress // ""' "$CONFIG_YAML_HOST"
}
detect_compose_network() {
local tag
tag=$(yq eval ".services[\"$COMBINED_SERVICE\"].networks | tag" "$COMPOSE_FILE" 2>/dev/null)
case "$tag" in
"!!seq")
yq eval ".services[\"$COMBINED_SERVICE\"].networks[0]" "$COMPOSE_FILE"
;;
"!!map")
yq eval ".services[\"$COMBINED_SERVICE\"].networks | keys | .[0]" "$COMPOSE_FILE"
;;
*)
echo "default"
;;
esac
}
# ---------------------------------------------------------------------------
# Renderers
# ---------------------------------------------------------------------------
# Build docker-compose.override.yml from the steps the operator selected.
# Service names match what we detected on the operator's side.
render_override() {
cat <<EOF
# Generated by migrate-to-enterprise.sh. Mode 644.
# Merged with docker-compose.yml automatically by Docker Compose.
# Remove this file (and config.yaml.enterprise if present) to revert.
services:
${DASHBOARD_SERVICE}:
image: \${NETBIRD_DASHBOARD_IMAGE:-ghcr.io/netbirdio/dashboard-cloud:latest}
${COMBINED_SERVICE}:
image: \${NETBIRD_SERVER_IMAGE:-ghcr.io/netbirdio/netbird-server-cloud:latest}
environment:
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
EOF
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
cat <<EOF
depends_on:
postgres:
condition: service_healthy
volumes:
- ./${ENTERPRISE_CONFIG_FILE}:/etc/netbird/config.yaml.enterprise:ro
command: ["--config", "/etc/netbird/config.yaml.enterprise"]
postgres:
image: postgres:17
container_name: netbird-postgres
restart: unless-stopped
networks: [${COMPOSE_NETWORK}]
environment:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: \${POSTGRES_PASSWORD}
POSTGRES_DB: netbird
volumes:
- netbird_postgres:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U netbird -d netbird"]
interval: 5s
timeout: 5s
retries: 20
EOF
fi
if [[ "$ENABLE_FLOW" == "yes" ]]; then
cat <<EOF
nats:
image: nats:2
container_name: netbird-nats
restart: unless-stopped
networks: [${COMPOSE_NETWORK}]
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
volumes:
- netbird_nats_data:/data
flow-enricher:
image: ghcr.io/netbirdio/flow-enricher-cloud:latest
container_name: netbird-flow-enricher
restart: unless-stopped
networks: [${COMPOSE_NETWORK}]
depends_on:
postgres:
condition: service_healthy
nats:
condition: service_started
environment:
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
NB_DATADIR: /var/lib/netbird
NB_MANAGEMENT_STORE_ENGINE: postgres
NB_MANAGEMENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
NB_STORE_ENGINE_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
NB_TRAFFIC_EVENT_STORE_ENGINE: postgres
NB_TRAFFIC_EVENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
NB_MANAGEMENT_STORE_KEY: \${NETBIRD_ENCRYPTION_KEY}
NB_FLOW_ADAPTER_TYPE: nats
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
NB_FLOW_NATS_STREAM: traffic-events
NB_METRICS_PORT: 9091
NB_PERSISTENCE_RETENTION_PERIOD: 168h
flow-receiver:
image: ghcr.io/netbirdio/flow-receiver-cloud:latest
container_name: netbird-flow-receiver
restart: unless-stopped
networks: [${COMPOSE_NETWORK}]
depends_on:
nats:
condition: service_started
environment:
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
NB_FLOW_LISTEN_PORT: 80
NB_FLOW_ADAPTER_TYPE: nats
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
NB_FLOW_NATS_STREAM: traffic-events
NB_FLOW_AUTH_SECRET: \${NB_FLOW_AUTH_SECRET}
labels:
- traefik.enable=true
- traefik.http.routers.netbird-flow.rule=Host(\`${NETBIRD_HOSTNAME}\`) && PathPrefix(\`/flow.FlowService/\`)
- traefik.http.routers.netbird-flow.entrypoints=websecure
- traefik.http.routers.netbird-flow.tls=true
- traefik.http.routers.netbird-flow.tls.certresolver=letsencrypt
- traefik.http.routers.netbird-flow.service=netbird-flow-h2c
- traefik.http.routers.netbird-flow.priority=100
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.port=80
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.scheme=h2c
EOF
fi
# Volume declarations for anything new the override introduced
local has_volumes="no"
if [[ "$MIGRATE_POSTGRES" == "yes" ]] || [[ "$ENABLE_FLOW" == "yes" ]]; then
has_volumes="yes"
fi
if [[ "$has_volumes" == "yes" ]]; then
cat <<EOF
volumes:
EOF
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo " netbird_postgres:"
fi
if [[ "$ENABLE_FLOW" == "yes" ]]; then
echo " netbird_nats_data:"
fi
fi
}
# Build config.yaml.enterprise by yq-editing the operator's existing
# config.yaml. We don't touch the original file.
render_enterprise_config() {
local pg_dsn="host=postgres user=netbird password=${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
yq eval "
.server.store.engine = \"postgres\" |
.server.store.dsn = \"$pg_dsn\" |
.server.activityStore.engine = \"postgres\" |
.server.activityStore.dsn = \"$pg_dsn\" |
.server.authStore.engine = \"postgres\" |
.server.authStore.dsn = \"$pg_dsn\"
" "$CONFIG_YAML_HOST" > "$ENTERPRISE_CONFIG_FILE"
if [[ "$ENABLE_FLOW" == "yes" ]]; then
local flow_addr="${NETBIRD_DOMAIN}"
yq eval -i "
.server.trafficFlow.enabled = true |
.server.trafficFlow.address = \"$flow_addr\" |
.server.trafficFlow.interval = \"60s\"
" "$ENTERPRISE_CONFIG_FILE"
fi
}
# ---------------------------------------------------------------------------
# Execution steps
# ---------------------------------------------------------------------------
resolve_data_volume() {
local short="$1"
local actual
# Resolve project-prefixed volume name from Docker Compose config first.
actual=$($DOCKER_COMPOSE_COMMAND config 2>/dev/null | yq eval ".volumes.\"$short\".name" - 2>/dev/null)
if [[ -n "$actual" && "$actual" != "null" ]]; then
echo "$actual"
return
fi
# Relative bind mount: docker-compose resolves it against the compose
# file's directory, but `docker run -v` resolves it against the current
# working directory. Normalize to an absolute path so both interpretations
# agree (and the printed revert command works from any CWD).
if [[ "$short" == ./* || "$short" == ../* ]]; then
local compose_dir
compose_dir="$(cd "$(dirname "$COMPOSE_FILE")" && pwd)"
(
cd "$compose_dir"
cd "$(dirname "$short")"
printf '%s/%s\n' "$(pwd)" "$(basename "$short")"
)
return
fi
# Not a named volume (e.g. an absolute bind-mount path) — use it as-is.
echo "$short"
}
backup_sqlite() {
BACKUP_DIR="$(pwd)/backups/sqlite-pre-enterprise-$(date +%Y%m%d-%H%M%S)"
mkdir -p "$BACKUP_DIR"
local data_volume_actual
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
echo "Backing up SQLite store from volume '$data_volume_actual' to $BACKUP_DIR ..."
docker run --rm \
-v "${data_volume_actual}:/var/lib/netbird:ro" \
-v "${BACKUP_DIR}:/backup" \
busybox \
sh -c 'cp -a /var/lib/netbird/. /backup/ 2>/dev/null || true'
local copied
copied=$(find "$BACKUP_DIR" -mindepth 1 | head -1)
if [[ -z "$copied" ]]; then
echo " ⚠ Backup directory is empty — the volume '$data_volume_actual' didn't contain data. Aborting." > /dev/stderr
exit 1
fi
echo " done"
}
run_migrate_store() {
echo "Running migrate-store (SQLite → Postgres) ..."
$DOCKER_COMPOSE_COMMAND run --rm "$COMBINED_SERVICE" migrate-store --config /etc/netbird/config.yaml.enterprise --verify
echo " done"
}
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
init_migration() {
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
check_yq
check_openssl
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
if [[ ! -f "$COMPOSE_FILE" ]]; then
echo "$COMPOSE_FILE not found in $(pwd)." > /dev/stderr
exit 1
fi
if [[ -f "$OVERRIDE_FILE" ]] || [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
echo "Migration artifacts already exist in $(pwd):"
[[ -f "$OVERRIDE_FILE" ]] && echo " $OVERRIDE_FILE"
[[ -f "$ENTERPRISE_CONFIG_FILE" ]] && echo " $ENTERPRISE_CONFIG_FILE"
echo ""
echo "Either you've already migrated, or a previous run was interrupted."
echo "To re-run cleanly: rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
exit 1
fi
COMBINED_SERVICE=$(detect_combined_service)
DASHBOARD_SERVICE=$(detect_dashboard_service)
CONFIG_YAML_HOST=$(detect_config_yaml_host_path)
DATA_VOLUME=$(detect_data_volume)
COMPOSE_NETWORK=$(detect_compose_network)
if [[ -z "$COMBINED_SERVICE" ]]; then
echo "Could not find a service running netbirdio/netbird-server* in $COMPOSE_FILE." > /dev/stderr
echo "This script targets the community combined-server deployment." > /dev/stderr
exit 1
fi
if [[ -z "$DASHBOARD_SERVICE" ]]; then
echo "Could not find a service running netbirdio/dashboard* in $COMPOSE_FILE." > /dev/stderr
exit 1
fi
if [[ -z "$CONFIG_YAML_HOST" ]]; then
echo "Could not find a config.yaml mount on $COMBINED_SERVICE (expected to bind-mount to /etc/netbird/config.yaml)." > /dev/stderr
exit 1
fi
if [[ ! -f "$CONFIG_YAML_HOST" ]]; then
echo "config.yaml host file not found at $CONFIG_YAML_HOST." > /dev/stderr
exit 1
fi
if [[ -z "$DATA_VOLUME" ]]; then
echo "Could not find a volume mounted at /var/lib/netbird on $COMBINED_SERVICE." > /dev/stderr
exit 1
fi
echo "Detected existing deployment:"
echo " Combined service: $COMBINED_SERVICE"
echo " Dashboard: $DASHBOARD_SERVICE"
echo " config.yaml: $CONFIG_YAML_HOST"
echo " Data volume: $DATA_VOLUME"
echo " Network: $COMPOSE_NETWORK"
echo ""
local proceed
proceed=$(read_yes_no "Proceed with migration?" "y")
if [[ "$proceed" != "yes" ]]; then
echo "Aborted."
exit 0
fi
# Step 1 — always (this is the point of the script)
MIGRATE_IMAGES="yes"
echo ""
echo "Step 1: Image swap (community → Enterprise). License key required."
NB_LICENSE_KEY=$(read_secret " License key")
GHCR_USERNAME="netbirdExtAccess1"
GHCR_TOKEN=$(read_secret " GHCR token (input hidden)")
# Step 2 — optional
echo ""
MIGRATE_POSTGRES=$(read_yes_no "Step 2: Migrate storage from SQLite to Postgres? (recommended)" "n")
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo ""
echo " ⚠ Data will be migrated from SQLite to Postgres. The SQLite store"
echo " will be backed up automatically. To fully revert later, restore"
echo " that backup and delete docker-compose.override.yml +"
echo " config.yaml.enterprise."
local confirm
confirm=$(read_yes_no " Continue?" "y")
if [[ "$confirm" != "yes" ]]; then
MIGRATE_POSTGRES="no"
echo " Skipping Postgres migration."
else
POSTGRES_PASSWORD=$(rand_password)
fi
fi
# Step 3 — optional, only if Postgres is on (flow requires Postgres)
echo ""
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
ENABLE_FLOW=$(read_yes_no "Step 3: Enable traffic flow? (requires Postgres)" "n")
if [[ "$ENABLE_FLOW" == "yes" ]]; then
# Auth secret MUST match server.authSecret from config.yaml
NB_FLOW_AUTH_SECRET=$(yq eval '.server.authSecret // ""' "$CONFIG_YAML_HOST")
if [[ -z "$NB_FLOW_AUTH_SECRET" ]] || [[ "$NB_FLOW_AUTH_SECRET" == "null" ]]; then
echo "Could not read server.authSecret from $CONFIG_YAML_HOST." > /dev/stderr
echo "Flow receiver auth must match the combined server's authSecret." > /dev/stderr
exit 1
fi
NETBIRD_DOMAIN=$(detect_exposed_address)
if [[ -z "$NETBIRD_DOMAIN" ]] || [[ "$NETBIRD_DOMAIN" == "null" ]]; then
NETBIRD_DOMAIN=$(read_required " Public NetBird URL (e.g. https://netbird.example.com)")
fi
# Strip protocol + port to leave just the hostname for the Traefik Host() rule.
NETBIRD_HOSTNAME=$(echo "$NETBIRD_DOMAIN" | sed -E 's,^https?://,,' | sed 's,:.*,,' | sed 's,/.*,,')
# We need the encryption key from the existing config.yaml for the enricher
NETBIRD_ENCRYPTION_KEY=$(yq eval '.server.store.encryptionKey // ""' "$CONFIG_YAML_HOST")
if [[ -z "$NETBIRD_ENCRYPTION_KEY" ]] || [[ "$NETBIRD_ENCRYPTION_KEY" == "null" ]]; then
echo "Could not read server.store.encryptionKey from $CONFIG_YAML_HOST." > /dev/stderr
exit 1
fi
fi
else
ENABLE_FLOW="no"
echo "Step 3 (traffic flow) skipped — requires Postgres."
fi
}
apply_changes() {
echo ""
echo "Writing $OVERRIDE_FILE ..."
install -m 644 /dev/null "$OVERRIDE_FILE"
render_override > "$OVERRIDE_FILE"
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' "$OVERRIDE_FILE" && rm -f "$OVERRIDE_FILE.bak"
fi
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo "Writing $ENTERPRISE_CONFIG_FILE ..."
install -m 600 /dev/null "$ENTERPRISE_CONFIG_FILE"
render_enterprise_config
fi
# Persist secrets that the override file references via env interpolation.
# We write them to a .env file in the current directory; docker compose
# picks it up automatically.
echo "Writing .env additions (mode 600) ..."
local ENV_FILE=".env"
touch "$ENV_FILE"
chmod 600 "$ENV_FILE"
{
echo ""
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
fi
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo "POSTGRES_PASSWORD=${POSTGRES_PASSWORD}"
fi
if [[ "$ENABLE_FLOW" == "yes" ]]; then
echo "NB_FLOW_AUTH_SECRET=${NB_FLOW_AUTH_SECRET}"
echo "NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}"
fi
} >> "$ENV_FILE"
echo ""
echo "Logging in to ghcr.io ..."
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
unset GHCR_TOKEN
echo ""
echo "Pulling enterprise images ..."
$DOCKER_COMPOSE_COMMAND pull
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo ""
echo "Stopping existing services (volumes preserved) ..."
$DOCKER_COMPOSE_COMMAND down
backup_sqlite
echo ""
echo "Starting Postgres ..."
$DOCKER_COMPOSE_COMMAND up -d postgres
# Wait for healthy
local counter=0
echo -n "Waiting for Postgres to become ready"
while ! $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird &> /dev/null; do
echo -n " ."
sleep 2
counter=$((counter + 1))
if [[ $counter -ge 60 ]]; then
echo ""
echo "Postgres did not become ready in 120s. Recent logs:"
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
exit 1
fi
done
echo " done"
run_migrate_store
fi
echo ""
echo "Bringing up all services ..."
$DOCKER_COMPOSE_COMMAND up -d
echo ""
echo "Migration complete."
}
print_summary() {
echo ""
echo "──────────────────────────────────────────────────────────────────────"
echo " Summary"
echo "──────────────────────────────────────────────────────────────────────"
echo " Images: swapped to enterprise"
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " Storage: Postgres (data migrated from SQLite)"
[[ "$MIGRATE_POSTGRES" != "yes" ]] && echo " Storage: SQLite (unchanged)"
[[ "$ENABLE_FLOW" == "yes" ]] && echo " Traffic flow: enabled"
[[ "$ENABLE_FLOW" != "yes" ]] && echo " Traffic flow: disabled"
echo ""
echo " Generated files (next to your docker-compose.yml):"
echo " $OVERRIDE_FILE"
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " $ENTERPRISE_CONFIG_FILE"
echo " .env (license key + secrets, mode 600)"
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " backups/sqlite-pre-enterprise-*/ (SQLite backup)"
echo ""
echo " Tail logs:"
echo " $DOCKER_COMPOSE_COMMAND logs -f $COMBINED_SERVICE"
echo ""
echo "──────────────────────────────────────────────────────────────────────"
echo " To revert"
echo "──────────────────────────────────────────────────────────────────────"
echo " $DOCKER_COMPOSE_COMMAND down"
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
# Resolve project-prefixed volume names now (before override is removed).
local pg_volume data_volume_actual
pg_volume=$(resolve_data_volume "netbird_postgres")
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
echo " # Remove the Postgres volume FIRST, before deleting the override file:"
echo " docker volume rm $pg_volume"
echo " # Restore SQLite from the backup created during this run:"
echo " docker run --rm -v ${data_volume_actual}:/var/lib/netbird -v ${BACKUP_DIR}:/backup busybox sh -c 'cp -a /backup/. /var/lib/netbird/'"
fi
echo " rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
echo " # Remove migrate-to-enterprise.sh additions from .env (search for the timestamp marker)"
echo " $DOCKER_COMPOSE_COMMAND up -d"
echo "──────────────────────────────────────────────────────────────────────"
}
# ---------------------------------------------------------------------------
# Run
# ---------------------------------------------------------------------------
init_migration
apply_changes
print_summary

View File

@@ -497,7 +497,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
@@ -610,10 +610,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peerID)
if err != nil {
return nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {

View File

@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
return srv
}
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -13,7 +13,7 @@ const (
reconnThreshold = 5 * time.Minute
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
metaChangeLimit = 5 // Number of reconnections with different metadata that triggers a ban of one peer
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
)
type lfConfig struct {
@@ -139,7 +139,7 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
state.lastSeen = now
}
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
h := fnv.New64a()
h.Write([]byte(meta.WtVersion))
@@ -147,6 +147,14 @@ func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return h.Sum64()
macs := uint64(0)
for _, na := range meta.NetworkAddresses {
for _, r := range na.Mac {
macs += uint64(r)
}
}
return h.Sum64() + macs
}

View File

@@ -164,7 +164,9 @@ func BenchmarkHashingMethods(b *testing.B) {
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
@@ -173,7 +175,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta)
resultString = builderString(meta, pubip)
}
})
@@ -181,7 +183,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta)
resultString = fnvHashToString(meta, pubip)
}
})
@@ -189,7 +191,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta)
resultUint = metaHash(meta, pubip)
}
})
@@ -197,20 +199,29 @@ func BenchmarkHashingMethods(b *testing.B) {
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta) string {
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
@@ -224,10 +235,23 @@ func builderString(meta nbpeer.PeerSystemMeta) string {
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000

View File

@@ -33,8 +33,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
proxyauth "github.com/netbirdio/netbird/proxy/auth"
@@ -84,9 +82,6 @@ type ProxyServiceServer struct {
// Manager for users
usersManager users.Manager
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
idpManager idp.Manager
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
@@ -162,7 +157,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -171,7 +166,6 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
idpManager: idpManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
@@ -1708,7 +1702,22 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
@@ -1745,45 +1754,6 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}, nil
}
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
// user or peer ID, and peer name or user email.
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
// Resolve the principal: when the peer is linked to a user, the human is the
// principal so multiple peers owned by the same user share a single
// identity. Unlinked peers (machine agents) are their own principal keyed on
// peer.ID. displayIdentity is what upstream gateways tag spend with —
// user.Email when linked, peer.Name when not.
// If the peer isn't associated with a user, return the peer info directly.
if peer.UserID == "" {
return peer.ID, peer.Name
}
// Otherwise, if the peer is linked to a user, the user is the principal and
// if an IdP is available, we gather details on the user from it.
principalID := peer.UserID
displayIdentity := peer.Name
// Stored column first (cheap, but often empty for OIDC-provisioned users).
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
// IdP enrichment wins when available — the stored email column is a
// best-effort cache and is frequently empty for OIDC users. Enrichment
// failures must never fail the RPC; we simply keep the stored/peer identity.
if s.idpManager != nil {
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
displayIdentity = ud.Email
} else if uerr != nil {
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
}
}
return principalID, displayIdentity
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security

View File

@@ -3,19 +3,14 @@ package grpc
import (
"context"
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockReverseProxyManager struct {
@@ -142,52 +137,6 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
return user, nil, nil
}
// mockTunnelPeersManager implements only the two peers.Manager methods that
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
// panics if any unexpected method is invoked).
type mockTunnelPeersManager struct {
peers.Manager
peer *peer.Peer
peerErr error
groups []*types.Group
groupsErr error
}
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
return m.peer, m.peerErr
}
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
return m.peer, m.groups, m.groupsErr
}
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
// an IdP that knows nothing about the user.
type mockTunnelIdpManager struct {
idp.Manager
email string
hasData bool
err error
gotCalls int
gotMeta []idp.AppMetadata
}
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
m.gotCalls++
m.gotMeta = append(m.gotMeta, meta)
if m.err != nil {
return nil, m.err
}
if !m.hasData {
// This might not be a thing any of the actual IDP implementations do,
// i.e. return a nil value with no error, but it seems valuable to test
// that behavior here.
return nil, nil //nolint:nilnil
}
return &idp.UserData{ID: userID, Email: m.email}, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
@@ -405,163 +354,6 @@ func TestValidateUserGroupAccess(t *testing.T) {
}
}
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
// (IdP email -> stored User.Email -> peer.Name).
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
const (
domain = "app.example.com"
accountID = "account1"
peerID = "peer1"
peerName = "peer-display-name"
userID = "user1"
)
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
tests := []struct {
name string
peerUserID string
storedUsers map[string]*types.User
storedErr error
noIdP bool
idpEmail string
idpHasData bool
idpErr error
expectEmail string
expectUserID string
expectIdPHit bool
}{
{
name: "idp email wins over stored email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp returns empty email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "",
idpHasData: true,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp has no data",
peerUserID: userID,
storedUsers: storedUser,
idpHasData: false,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp errors",
peerUserID: userID,
storedUsers: storedUser,
idpErr: errors.New("idp unreachable"),
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when no idp manager",
peerUserID: userID,
storedUsers: storedUser,
noIdP: true,
expectEmail: "stored@example.com",
expectUserID: userID,
},
{
name: "idp email when stored email is empty",
peerUserID: userID,
storedUsers: storedUserNoEmail,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "idp email when stored user missing keeps peer.UserID as principal",
peerUserID: userID,
storedUsers: map[string]*types.User{},
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "unlinked peer uses peer name and never consults idp",
peerUserID: "",
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: peerName,
expectUserID: peerID,
expectIdPHit: false,
},
{
name: "linked peer with empty stored email and no idp falls back to peer name",
peerUserID: userID,
storedUsers: storedUserNoEmail,
noIdP: true,
expectEmail: peerName,
expectUserID: userID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := &service.Service{Domain: domain, AccountID: accountID}
server := &ProxyServiceServer{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
},
peersManager: &mockTunnelPeersManager{
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
},
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
}
var idpMock *mockTunnelIdpManager
if !tt.noIdP {
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
server.idpManager = idpMock
}
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
Domain: domain,
TunnelIp: "100.64.0.1",
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.True(t, resp.GetValid(), "expected access granted")
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
assert.Equal(t, tt.expectUserID, resp.GetUserId())
if idpMock != nil {
if tt.expectIdPHit {
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
require.Len(t, idpMock.gotMeta, 1)
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
} else {
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
}
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string

View File

@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return mapError(ctx, err)
}
metahashed := metaHash(peerMeta)
metahashed := metaHash(peerMeta, sRealIP)
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
metahash := metaHash(peerMeta)
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta)
metahashed := metaHash(peerMeta, sRealIP)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
@@ -788,11 +788,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
ExtraDNSLabels: loginReq.GetDnsLabels(),
})
if err != nil {
if errors.Is(err, internalStatus.ErrNoAuthMethodProvided) {
log.WithContext(ctx).Tracef("failed logging in peer %s: %s", peerKey, err)
} else {
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
}
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(ctx, err)
}
@@ -1209,7 +1205,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
return nil, msg
}
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
if err != nil {
return nil, mapError(ctx, err)
}
@@ -1258,10 +1254,7 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
for _, postureCheck := range postureChecks {
check := toProtocolCheck(postureCheck)
if check != nil {
protoChecks = append(protoChecks, check)
}
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
}
return protoChecks
@@ -1285,9 +1278,5 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
}
}
if len(protoCheck.Files) == 0 {
return nil
}
return protoCheck
}

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)

View File

@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
// concurrent stream that started earlier loses the optimistic-lock race
// in MarkPeerConnected and bails without writing.
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
return nil
}
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
return err
}
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
return err
}
@@ -2045,7 +2045,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain, email, nam
Extra: &types.ExtraSettings{
UserApprovalRequired: true,
},
LazyConnectionEnabled: true,
},
Onboarding: types.AccountOnboarding{
OnboardingFlowPending: true,

View File

@@ -62,7 +62,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -123,7 +123,7 @@ type Manager interface {
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)

View File

@@ -1323,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, accountID, sessionStartedAt, nmap)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, accountID, sessionStartedAt, nmap)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
}
// MarkPeerDisconnected mocks base method.
@@ -1586,17 +1586,17 @@ func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *g
}
// SyncPeerMeta mocks base method.
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP) error {
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta, realIP)
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta)
ret0, _ := ret[0].(error)
return ret0
}
// SyncPeerMeta indicates an expected call of SyncPeerMeta.
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta, realIP interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta, realIP)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta)
}
// SyncUserJWTGroups mocks base method.

View File

@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1916,117 +1916,6 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
InactivityExpirationEnabled: true,
}, false)
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
PeerInactivityExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
// Establish a session so the matching-token disconnect is actually applied.
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
// Install the mock only now, so the assertion observes the disconnect, not
// the earlier connect.
scheduled := make(chan struct{}, 1)
manager.peerInactivityExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
select {
case scheduled <- struct{}{}:
default:
}
},
}
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer disconnected")
select {
case <-scheduled:
// expected: disconnect re-armed the inactivity expiry timer
case <-time.After(time.Second):
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
}
}
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
InactivityExpirationEnabled: true,
}, false)
require.NoError(t, err, "unable to add peer")
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
// stays disabled, so disconnect must not schedule anything.
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
PeerInactivityExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
scheduled := make(chan struct{}, 1)
manager.peerInactivityExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
select {
case scheduled <- struct{}{}:
default:
}
},
}
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer disconnected")
select {
case <-scheduled:
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
case <-time.After(200 * time.Millisecond):
// expected: nothing scheduled
}
}
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -2046,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2067,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2091,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node2SyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2101,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
"SessionStartedAt should equal node2SyncTime token")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node1StaleSyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2163,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, token, nil)
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
}()
}
@@ -2204,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -3326,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err

View File

@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
},
},
{
@@ -106,9 +106,11 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
change, mustContain, mustExclude := r.build(t, s, ctx)
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
for _, peerID := range mustExclude {
assert.NotContains(t, affected, peerID, "peer must not be affected")
for _, id := range mustContain {
assert.Contains(t, affected, id, "expected peer to be affected")
}
for _, id := range mustExclude {
assert.NotContains(t, affected, id, "peer must not be affected")
}
})
}

View File

@@ -251,9 +251,7 @@ func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSou
}
}
// A disabled sibling router routes to nobody, so updating a resource on its network
// must NOT refresh its peer (the enabled router carries the bridge instead).
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *testing.T) {
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -276,18 +274,13 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *tes
require.NoError(t, err)
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
enabledCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
settleAffectedUpdates(disabledCh, enabledCh)
settleAffectedUpdates(disabledCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, enabledCh)
peerShouldNotReceiveUpdate(t, disabledCh)
peerShouldReceiveUpdate(t, disabledCh)
close(done)
}()
@@ -305,7 +298,7 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *tes
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout")
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
}
}

View File

@@ -682,9 +682,6 @@ func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
}
// A disabled router in the snapshot routes to nobody, so it is skipped when the
// walk scans existing account data: a policy edit still folds the literal source
// group, but not the disabled router's peer.
func TestAffectedPeers_DisabledRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -697,13 +694,11 @@ func TestAffectedPeers_DisabledRouter(t *testing.T) {
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
assert.NotContains(t, affected, s.routerPeerID,
"a disabled router routes to nobody, so its peer must not be folded from snapshot data")
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
}
// A disabled resource in the snapshot is skipped: the policy edit still folds the
// literal source group, but the resource no longer bridges to its network's router.
func TestAffectedPeers_DisabledResource(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -715,9 +710,9 @@ func TestAffectedPeers_DisabledResource(t *testing.T) {
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
assert.NotContains(t, affected, s.routerPeerID,
"a disabled resource routes to nobody, so its network's router must not be folded from snapshot data")
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
}
func TestAffectedPeers_DisabledRule(t *testing.T) {

View File

@@ -96,54 +96,33 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
ctx := context.Background()
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
Enabled: true,
Sources: []string{groupIDs[0]},
Destinations: []string{groupIDs[1]},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
assert.Empty(t, groups)
assert.Empty(t, directPeers)
}
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
@@ -154,44 +133,20 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
Enabled: true,
Sources: []string{groupIDs[0]},
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
Destinations: []string{groupIDs[1]},
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
assert.Empty(t, groups)
assert.Empty(t, directPeers)
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.Contains(t, directPeers, peerIDs[3])
}
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
@@ -213,7 +168,8 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
}
@@ -338,7 +294,6 @@ func TestCollectGroupChange_NetworkRouterLinked(t *testing.T) {
AccountID: accountID,
PeerGroups: []string{groupIDs[0]},
Peer: peerIDs[3],
Enabled: true,
})
require.NoError(t, err)
@@ -369,7 +324,6 @@ func TestCollectGroupChange_NetworkRouterPeerOnlyNoGroups(t *testing.T) {
NetworkID: net1.ID,
AccountID: accountID,
Peer: peerIDs[4],
Enabled: true,
})
require.NoError(t, err)
@@ -419,11 +373,17 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
require.NoError(t, err)
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
assert.Contains(t, groups, groupIDs[0])
assert.Contains(t, groups, groupIDs[1])
assert.NotContains(t, groups, groupIDs[2])
assert.NotContains(t, groups, groupIDs[3])
assert.Empty(t, directPeers)
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
assert.Contains(t, groups, groupIDs[2])
assert.Contains(t, groups, groupIDs[3])
assert.NotContains(t, groups, groupIDs[0])
assert.NotContains(t, groups, groupIDs[1])
assert.Empty(t, directPeers)
}
@@ -492,9 +452,8 @@ func TestResolveAffectedPeers_PolicyBetweenTwoGroups(t *testing.T) {
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
// peerIDs[2] is unrelated to the route; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
assert.Empty(t, result)
}
func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
@@ -515,7 +474,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
require.NoError(t, err)
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
}
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
@@ -547,9 +506,8 @@ func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
// peerIDs[2] is in no policy; only its own map can change, so it refreshes itself.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
assert.Empty(t, result)
}
func TestResolveAffectedPeers_RouteWithDirectPeer(t *testing.T) {
@@ -606,9 +564,9 @@ func TestResolveAffectedPeers_RouteWithAccessControlGroups(t *testing.T) {
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
// peer3 is unrelated to the route; only its own map can change.
// peer3 is unrelated
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[3]})
assert.ElementsMatch(t, []string{peerIDs[3]}, result)
assert.Empty(t, result)
}
func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
@@ -629,7 +587,6 @@ func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
AccountID: accountID,
PeerGroups: []string{groupIDs[0]},
Peer: peerIDs[3],
Enabled: true,
})
require.NoError(t, err)
@@ -702,13 +659,9 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
}, true)
require.NoError(t, err)
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
// of group1, is a sibling of the changed peer and must NOT refresh.
// peer0 is in group0 AND group1, so both policies apply
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
}
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
@@ -744,7 +697,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
require.NoError(t, err)
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
}
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {
@@ -901,9 +854,8 @@ func TestAffectedPeers_IsolatedPolicies(t *testing.T) {
assert.NotContains(t, result, peerIDs[0])
assert.NotContains(t, result, peerIDs[1])
// peerIDs[4] is in neither isolated policy; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[4]})
assert.ElementsMatch(t, []string{peerIDs[4]}, result)
assert.Empty(t, result)
}
func TestAffectedPeers_IsolatedRouteAndPolicy(t *testing.T) {
@@ -1025,13 +977,12 @@ func TestAffectedPeers_GroupUpdateOnlyAffectsLinkedPeers(t *testing.T) {
})
}
// A peer in no policy/route refreshes only itself — no other peer is affected.
func TestAffectedPeers_UnlinkedPeerChange_RefreshesSelfOnly(t *testing.T) {
func TestAffectedPeers_UnlinkedGroupChange_NoUpdates(t *testing.T) {
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
ctx := context.Background()
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.ElementsMatch(t, []string{peerIDs[0]}, result)
assert.Empty(t, result)
}
// TestAffectedPeers_PolicyChange_UnrelatedPeerNoUpdate verifies that creating/deleting a
@@ -1381,7 +1332,6 @@ func TestAffectedPeers_NetworkRouterUnlinkedPeerNoUpdate(t *testing.T) {
NetworkID: net1.ID,
AccountID: accountID,
PeerGroups: []string{"nr-grpA"},
Enabled: true,
})
require.NoError(t, err)
@@ -1805,9 +1755,7 @@ func TestCollectAffectedFromProxyServices_GroupContainingTargetPeerChanged(t *te
assert.Contains(t, directPeers, peerIDs[1], "target peer must be refreshed")
}
// A disabled service in the snapshot proxies nothing, so it is skipped: a changed
// target peer does not pull in the service's proxy peer.
func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing.T) {
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
ctx := context.Background()
@@ -1833,7 +1781,8 @@ func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
require.NoError(t, s.CreateService(ctx, svc))
_, directPeers := collectPeerChangeAffectedGroups(ctx, manager.Store, accountID, nil, []string{peerIDs[1]})
assert.NotContains(t, directPeers, peerIDs[0], "a disabled service proxies nothing, so its proxy peer must not be folded")
assert.Contains(t, directPeers, peerIDs[0], "disabled service should still trigger a refresh so peers are ready when re-enabled")
assert.Contains(t, directPeers, peerIDs[1], "disabled target should still trigger a refresh")
}
func TestCollectAffectedFromProxyServices_NonPeerTargetType(t *testing.T) {

View File

@@ -6,12 +6,7 @@
// and before a delete/removal severs the old state).
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
//
// Enabled handling differs by source. Disabled objects in the SNAPSHOT (existing
// account policies/resources/routers/routes/proxy services and their rules/targets)
// route to nobody and are skipped — they cannot affect any peer's map. Objects in
// the CHANGE itself are processed regardless of Enabled, so disabling one still
// refreshes the peers that lose access (the toggle is the observable change, and the
// update carries the oldnew state).
// Enabled is never consulted: toggling it is itself an observable change.
package affectedpeers
import (
@@ -66,8 +61,7 @@ func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snap
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
// collections a Change can touch, gated to what the walk needs.
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
// LinkGroups drive the same policy/route/dns walk as a changed group or peer.
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 || len(c.Resources) > 0
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
// the resource<->router bridge can fire for any of these
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
@@ -82,7 +76,7 @@ func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accoun
return err
}
}
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 {
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
if err := snap.loadDNS(ctx, s, accountID); err != nil {
return err
}
@@ -180,24 +174,6 @@ type Change struct {
// folded in — but only when the group is linked (an unlinked group has no map
// impact), matching how current members are handled.
RemovedPeersByGroup map[string][]string
// OutputPeerIDs are peers folded straight into the result without seeding their
// group memberships into the walk. Use for the peer whose group membership changed:
// the peer itself must refresh, but its OTHER groups did not change, so they must
// not be walked. Contrast ChangedPeerIDs, which seeds ALL of the peer's groups
// (correct when the peer's own attributes changed, e.g. IP/status).
OutputPeerIDs []string
// LinkGroups are groups used ONLY to match policies/routes/routers and walk to the
// OPPOSITE side — they are never expanded to their own members. Use this when a
// peer's group membership changed: pass the peer in ChangedPeerIDs and its
// group(s) here. The opposite side of the policies the group participates in
// refreshes, but the group's other members (siblings) do not — nothing changed for
// them. For an intra-group policy (A→A) the opposite side IS the group, so its
// members still refresh via the opposite-side fold, exactly when they genuinely
// gain/lose the changed peer. Unlike ChangedGroupIDs, a LinkGroup is not added to
// the output, so a one-sided membership change never wakes the whole group.
LinkGroups []string
}
func (c Change) isEmpty() bool {
@@ -210,9 +186,7 @@ func (c Change) isEmpty() bool {
len(c.Networks) == 0 &&
len(c.PostureCheckIDs) == 0 &&
len(c.DistributionGroupIDs) == 0 &&
len(c.RemovedPeersByGroup) == 0 &&
len(c.LinkGroups) == 0 &&
len(c.OutputPeerIDs) == 0
len(c.RemovedPeersByGroup) == 0
}
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
@@ -223,8 +197,8 @@ func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []
return nil
}
r := newResolver(ctx, snap, accountID, c)
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v linkGroups=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, c.LinkGroups, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
r.walk()
return r.expand()
}
@@ -242,84 +216,57 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
}
r := newResolver(ctx, snap, accountID, c)
r.walk()
return setToSlice(r.affectedGroups), setToSlice(r.affectedPeers)
return setToSlice(r.groupSet), setToSlice(r.peerSet)
}
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
r := &resolver{
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
linkGroups: toSet(c.ChangedGroupIDs),
outputGroups: toSet(c.ChangedGroupIDs),
changedPeers: toSet(c.ChangedPeerIDs),
affectedGroups: make(map[string]struct{}),
affectedPeers: make(map[string]struct{}),
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
changedGroupSet: toSet(c.ChangedGroupIDs),
changedPeerSet: toSet(c.ChangedPeerIDs),
groupSet: make(map[string]struct{}),
peerSet: make(map[string]struct{}),
networkIDs: make(map[string]struct{}),
}
// LinkGroups match policies/routes to find the opposite side but are NOT output:
// they go into linkGroups only, never outputGroups, so their members never fold in.
addAll(r.linkGroups, c.LinkGroups)
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
r.seedChangedGroupsFromPeers()
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
return r
}
// seedChangedGroupsFromPeers adds each changed peer's groups to linkGroups so
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
// the group-driven walkers fire for memberships, not just direct peer references.
// These seeded groups are for MATCHING only — folding the changed entity's own
// side is gated on outputGroups (the caller-reported groups), so a seeded group
// never folds its whole membership; only the changed peer itself folds in.
func (r *resolver) seedChangedGroupsFromPeers() {
if len(r.changedPeers) == 0 {
if len(r.changedPeerSet) == 0 {
return
}
for groupID, members := range r.snap.groupPeers {
for pID := range r.changedPeers {
for pID := range r.changedPeerSet {
if _, ok := members[pID]; ok {
r.linkGroups[groupID] = struct{}{}
r.changedGroupSet[groupID] = struct{}{}
break
}
}
}
}
// policySide selects which side of a policy rule to walk.
type policySide int
const (
sideSource policySide = iota
sideDestination
)
func (s policySide) opposite() policySide {
if s == sideSource {
return sideDestination
}
return sideSource
}
// walk resolves affected peers in two buckets, by how far each change propagates.
//
// BOTH-SIDES — the rule itself changed (an explicit policy edit, or a policy whose
// posture check changed). Source AND destination refresh, so each such policy is
// walked on both sides.
//
// OPPOSITE-SIDE — an endpoint moved but no rule changed. For each policy the change
// touches we fold only the side AWAY from the change:
// - a changed peer/group sits ON a policy side -> fold the opposite side;
// - a changed router/resource/network sits on a NETWORK -> fold the SOURCE side of
// the policies whose destination reaches it (and the routers it implies).
//
// Routes, nameserver groups, DNS and embedded-proxy services distribute to their own
// member peers, outside the policy graph, and are folded here too.
func (r *resolver) walk() {
for _, policy := range r.bothSidesPolicies() {
r.foldPolicySide(policy, sideSource)
r.foldPolicySide(policy, sideDestination)
}
r.collectFromExplicitPolicies()
r.collectFromExplicitRoutes(r.change.Routes)
r.collectFromExplicitRouters(r.change.Routers)
r.collectFromExplicitResources(r.change.Resources)
r.collectFromExplicitNetworks(r.change.Networks)
r.collectFromPostureChecks(r.change.PostureCheckIDs)
if len(r.linkGroups) > 0 || len(r.changedPeers) > 0 {
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
// straight into groupSet so expand() maps them to members, without the policy/
// route walk that changedGroupSet would trigger.
addAll(r.groupSet, r.change.DistributionGroupIDs)
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
r.collectFromPolicies()
r.collectFromRoutes()
r.collectFromNameServers()
@@ -328,31 +275,7 @@ func (r *resolver) walk() {
r.collectFromProxyServices()
}
r.collectFromChangedRoutes(r.change.Routes)
r.collectFromChangedRouters(r.change.Routers)
r.collectFromChangedResources(r.change.Resources)
r.collectFromChangedNetworks(r.change.Networks)
// The explicitly changed peers always refresh their own maps. OnPeersUpdated only
// refreshes the resolver's output (it ignores the separately-passed changed peers),
// so the changed peer reaches its own new map only via here. An offline/deleted
// peer in the set is filtered downstream (filterConnectedAffectedPeers).
addAll(r.affectedPeers, setToSlice(r.changedPeers))
// OutputPeerIDs refresh themselves too, but unlike changedPeers their group
// memberships were not seeded into the walk (only the changed group was).
addAll(r.affectedPeers, r.change.OutputPeerIDs)
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
// straight into affectedGroups so expand() maps them to members, without the
// policy/route walk that linkGroups would trigger.
addAll(r.affectedGroups, r.change.DistributionGroupIDs)
}
// bothSidesPolicies are the policies whose rule changed: the explicitly edited ones
// plus those gated by a changed posture check. walk folds both their sides.
func (r *resolver) bothSidesPolicies() []*types.Policy {
policies := append([]*types.Policy(nil), r.change.Policies...)
return r.appendPoliciesForPostureChecks(policies, r.change.PostureCheckIDs)
r.collectResourceRouterBridge()
}
type resolver struct {
@@ -361,71 +284,27 @@ type resolver struct {
accountID string
change Change
// Inputs — what changed. Set once at construction, read-only during the walk
// (except linkGroups, which collectFromExplicitResources also seeds).
//
// linkGroups is the MATCH set: caller-changed groups the groups of changed
// peers changed-resource groups. A rule/route/router matches the change when
// one of its groups is here — used only to find the opposite side to fold.
//
// outputGroups is the FOLD-WHOLE-GROUP set: ONLY Change.ChangedGroupIDs. When a
// matched group is here, its whole membership is affected. A peer-seeded group
// is in linkGroups but NOT outputGroups, so it folds only the changed peer
// (changedPeers), never its siblings.
linkGroups map[string]struct{}
outputGroups map[string]struct{}
changedPeers map[string]struct{}
changedGroupSet map[string]struct{}
changedPeerSet map[string]struct{}
// Outputs — the answer. The only sets the walk accumulates into. affectedGroups
// is expanded to its member peers in expand().
affectedGroups map[string]struct{}
affectedPeers map[string]struct{}
groupSet map[string]struct{}
peerSet map[string]struct{}
matchedPolicies []*types.Policy
networkIDs map[string]struct{}
}
// policies returns the account's ENABLED policies from the snapshot. Disabled
// policies grant no access, so the walk skips them when scanning existing account
// data. Explicitly changed policies (Change.Policies, via bothSidesPolicies) are
// processed regardless of Enabled, so disabling one still refreshes its peers.
func (r *resolver) policies() []*types.Policy {
enabled := make([]*types.Policy, 0, len(r.snap.policies))
for _, policy := range r.snap.policies {
if policy != nil && policy.Enabled {
enabled = append(enabled, policy)
}
}
return enabled
}
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
// networkResources / networkRouters return the account's ENABLED resources/routers
// from the snapshot. Disabled objects route to nobody, so the walk skips them when
// it scans existing account data. The explicitly changed objects in the Change are
// processed regardless of Enabled (collectFromChanged*), so disabling one still
// refreshes the peers that lose access.
func (r *resolver) networkResources() []*resourceTypes.NetworkResource {
enabled := make([]*resourceTypes.NetworkResource, 0, len(r.snap.resources))
for _, resource := range r.snap.resources {
if resource.Enabled {
enabled = append(enabled, resource)
}
}
return enabled
}
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter {
enabled := make([]*routerTypes.NetworkRouter, 0, len(r.snap.routers))
for _, router := range r.snap.routers {
if router.Enabled {
enabled = append(enabled, router)
}
}
return enabled
}
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
seen := make(map[string]struct{})
var ids []string
for gID := range groups {
for gID := range groupSet {
for pID := range r.snap.groupPeers[gID] {
if _, ok := seen[pID]; ok {
continue
@@ -438,25 +317,25 @@ func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
}
func (r *resolver) expand() []string {
peerIDs := r.peerIDsForGroups(r.affectedGroups)
peerIDs := r.peerIDsForGroups(r.groupSet)
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
r.accountID, setToSlice(r.affectedGroups), len(peerIDs), setToSlice(r.affectedPeers))
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for id := range r.affectedPeers {
for id := range r.peerSet {
if _, ok := seen[id]; !ok {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
// Fold in removed peers only when their group is linked (in affectedGroups).
// Fold in removed peers only when their group is linked (in groupSet).
for groupID, removed := range r.change.RemovedPeersByGroup {
if _, linked := r.affectedGroups[groupID]; !linked {
if _, linked := r.groupSet[groupID]; !linked {
continue
}
for _, id := range removed {
@@ -472,349 +351,169 @@ func (r *resolver) expand() []string {
return peerIDs
}
// ruleSideGroups / ruleSideResource return the groups and the resource on the given
// side of a rule.
func ruleSideGroups(rule *types.PolicyRule, side policySide) []string {
if side == sideDestination {
return rule.Destinations
}
return rule.Sources
}
func ruleSideResource(rule *types.PolicyRule, side policySide) types.Resource {
if side == sideDestination {
return rule.DestinationResource
}
return rule.SourceResource
}
// foldPolicySide folds one side of a policy down to affected peers: its groups
// (resolved to members in expand) and its direct peer. When the side is the
// DESTINATION and references a network resource (directly or via a destination
// group's resources), it also folds the routers that serve that resource's network
// — a destination resource is reached through its routers. A resource on the SOURCE
// side routes to nobody (GetPoliciesForNetworkResource matches destinations only),
// so the router hop is destination-only.
func (r *resolver) foldPolicySide(policy *types.Policy, side policySide) {
if policy == nil {
return
}
for _, rule := range policy.Rules {
addAll(r.affectedGroups, ruleSideGroups(rule, side))
res := ruleSideResource(rule, side)
if res.Type == types.ResourceTypePeer && res.ID != "" {
r.affectedPeers[res.ID] = struct{}{}
}
}
if side == sideDestination {
r.foldRoutersForResources(r.policyDestinationResourceIDs(policy))
}
}
// appendPoliciesForPostureChecks appends every policy that references a changed
// posture check (a rule change, so walk both sides).
func (r *resolver) appendPoliciesForPostureChecks(policies []*types.Policy, postureCheckIDs []string) []*types.Policy {
if len(postureCheckIDs) == 0 {
return policies
}
ids := toSet(postureCheckIDs)
for _, policy := range r.policies() {
if !policyReferencesPostureChecks(policy, ids) || !policy.Enabled {
func (r *resolver) collectFromExplicitPolicies() {
for _, policy := range r.matchedPolicies {
if policy == nil {
continue
}
log.WithContext(r.ctx).Tracef("appendPoliciesForPostureChecks: policy %s (%s) references changed posture checks %v -> both-sides policy",
policy.ID, policy.Name, postureCheckIDs)
policies = append(policies, policy)
}
return policies
}
// collectFromPolicies folds, for every policy whose rule a changed group or peer
// touches, only the OPPOSITE side (down to peers, incl. destination routers), plus
// the changed entity's own side: the changed group's whole membership when the
// group itself changed (outputGroups), or the changed peer alone when matched via a
// peer-seeded group (never its co-members).
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue // a disabled rule grants no access
}
r.foldRuleSideIfChanged(policy, rule, sideSource)
r.foldRuleSideIfChanged(policy, rule, sideDestination)
}
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
}
}
// foldRuleSideIfChanged: when a changed group or direct peer sits on `side` of the
// rule, fold the opposite side fully (groups/peers + destination routers) and fold
// the changed entity's own side (the whole changed group, or the changed peer alone).
func (r *resolver) foldRuleSideIfChanged(policy *types.Policy, rule *types.PolicyRule, side policySide) {
nearGroups := ruleSideGroups(rule, side)
nearResource := ruleSideResource(rule, side)
matchedByGroup := anyInSet(nearGroups, r.linkGroups)
matchedByPeer := isDirectPeerInSet(nearResource, r.changedPeers)
if !matchedByGroup && !matchedByPeer {
return
}
// Opposite side, fully down to peers (a destination opposite also folds routers).
r.foldPolicySideForRule(policy, rule, side.opposite())
// Own side: fold the whole changed group's members only when the group itself
// changed (outputGroups). A peer-seeded or link-only group is not folded here —
// its siblings never refresh. The changed peers themselves are folded once, after
// the walk (see walk()).
for _, gID := range nearGroups {
if _, ok := r.outputGroups[gID]; ok {
r.affectedGroups[gID] = struct{}{}
}
}
// When the changed side IS a destination, the resources it targets are reached
// through their network's routers, so those routers refresh too (e.g. attaching a
// resource to a destination group, or a changed destination group/resource).
if side == sideDestination {
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
}
}
// foldPolicySideForRule folds one side of a single rule (groups + direct peer), and
// for a destination side the routers of that rule's destination resources.
func (r *resolver) foldPolicySideForRule(policy *types.Policy, rule *types.PolicyRule, side policySide) {
addAll(r.affectedGroups, ruleSideGroups(rule, side))
res := ruleSideResource(rule, side)
if res.Type == types.ResourceTypePeer && res.ID != "" {
r.affectedPeers[res.ID] = struct{}{}
}
if side == sideDestination {
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
}
}
// collectFromChangedRoutes folds an explicitly changed route's own groups and peer.
func (r *resolver) collectFromChangedRoutes(routes []*route.Route) {
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
for _, rt := range routes {
if rt == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.affectedPeers[rt.Peer] = struct{}{}
r.peerSet[rt.Peer] = struct{}{}
}
}
}
// collectFromChangedRouters: a changed router refreshes its OWN backing peer/groups
// (the changed entity) and the SOURCE side of every policy reaching a resource on
// its network (the router serves the whole network). Sibling routers on the network
// are independent and are NOT folded. Passing the old router state keeps a repointed
// router's previous backing affected without a post-commit read.
func (r *resolver) collectFromChangedRouters(routers []*routerTypes.NetworkRouter) {
// collectFromExplicitRouters folds changed routers' peers and marks their networks
// for the bridge. Passing the old router keeps a repointed router's previous peers
// affected without a post-commit read.
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
for _, router := range routers {
if router == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedRouters: changed router %s on network %s -> folding its own peerGroups=%v peer=%q + sources reaching network resources",
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.affectedGroups, router.PeerGroups)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.affectedPeers[router.Peer] = struct{}{}
r.peerSet[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
r.networkIDs[router.NetworkID] = struct{}{}
}
}
}
// collectFromChangedResources: a changed resource refreshes the SOURCE side of the
// policies targeting EXACTLY that resource — directly, or via one of the resource's
// own groups (oldnew across the change, so a now-detached group's sources still
// refresh) — plus the routers serving its network (the resource is reached through
// them). It does not touch sibling resources on the same network.
func (r *resolver) collectFromChangedResources(resources []*resourceTypes.NetworkResource) {
// collectFromExplicitResources marks changed resources' networks for the bridge and
// treats their group IDs as changed, so policies targeting the resource via a
// now-detached (old) group still refresh.
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
for _, resource := range resources {
if resource == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedResources: changed resource %s on network %s (groups %v) -> folding sources of policies targeting it + its network's routers",
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
resource.ID, resource.NetworkID, resource.GroupIDs)
r.foldPolicySourcesForResource(resource.ID, resource.GroupIDs)
addAll(r.changedGroupSet, resource.GroupIDs)
if resource.NetworkID != "" {
r.foldRoutersOnNetworks(map[string]struct{}{resource.NetworkID: {}})
r.networkIDs[resource.NetworkID] = struct{}{}
}
}
}
// foldPolicySourcesForResource folds the source side of every policy whose
// destination is the given resource — referenced directly, or via any of the given
// groups (the resource's own oldnew groups, which captures a detached group).
func (r *resolver) foldPolicySourcesForResource(resourceID string, groupIDs []string) {
groups := toSet(groupIDs)
for _, policy := range r.policies() {
if !policyTargetsResourceOrGroups(policy, resourceID, groups) {
continue
}
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResource: policy %s (%s) targets changed resource %s -> folding its source groups/peers", policy.ID, policy.Name, resourceID)
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
}
}
// policyTargetsResourceOrGroups reports whether a policy's destination is the given
// resource directly, or one of the given destination groups.
func policyTargetsResourceOrGroups(policy *types.Policy, resourceID string, groups map[string]struct{}) bool {
if policy == nil {
return false
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID == resourceID && resourceID != "" {
return true
}
if anyInSet(rule.Destinations, groups) {
return true
}
}
return false
}
// collectFromChangedNetworks: a changed network refreshes the SOURCE side of the
// policies reaching any of its resources, plus its routers. A network has no
// groups/peers of its own.
func (r *resolver) collectFromChangedNetworks(networks []*networkTypes.Network) {
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
// no groups/peers of its own.
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
for _, network := range networks {
if network == nil || network.ID == "" {
if network == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromChangedNetworks: changed network %s -> folding sources reaching its resources + its routers", network.ID)
resourceIDs := r.networkResourceIDs(network.ID)
r.foldPolicySourcesForResources(resourceIDs)
r.foldRoutersOnNetworks(map[string]struct{}{network.ID: {}})
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
if network.ID != "" {
r.networkIDs[network.ID] = struct{}{}
}
}
}
// foldPolicySourcesForResources folds the source groups/peers of every policy whose
// destination targets one of resourceIDs (directly or via a destination group).
func (r *resolver) foldPolicySourcesForResources(resourceIDs map[string]struct{}) {
if len(resourceIDs) == 0 {
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
if len(postureCheckIDs) == 0 {
return
}
ids := toSet(postureCheckIDs)
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResources: policy %s (%s) targets a changed resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
if !policyReferencesPostureChecks(policy, ids) {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
// collectFromRoutes folds, per matched route, the OPPOSITE side(s) fully and the
// matched side's own groups only on a whole-group change (outputGroups). A route has
// three peer sides — routing (Peer/PeerGroups), consumer (Groups) and ACL
// (AccessControlGroups) — that each refresh the others; the changed side's own group
// folds its siblings only when the group itself changed, never on a one-peer move.
func (r *resolver) collectFromRoutes() {
for _, rt := range r.snap.routes {
if !rt.Enabled {
continue // disabled routes route to nobody; skip existing account data
}
routing := anyInSet(rt.PeerGroups, r.linkGroups) || (rt.Peer != "" && isInSet(rt.Peer, r.changedPeers))
consumer := anyInSet(rt.Groups, r.linkGroups)
acl := anyInSet(rt.AccessControlGroups, r.linkGroups)
if !routing && !consumer && !acl {
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (routing=%t consumer=%t acl=%t) -> folding opposite sides; own side gated on outputGroups",
rt.ID, routing, consumer, acl)
r.foldRouteSide(rt.PeerGroups, routing)
r.foldRouteSide(rt.Groups, consumer)
r.foldRouteSide(rt.AccessControlGroups, acl)
// The single routing Peer folds when the routing side is the OPPOSITE of the
// match (consumer/acl need it), or when that very peer is the change.
if rt.Peer != "" && (consumer || acl || isInSet(rt.Peer, r.changedPeers)) {
r.affectedPeers[rt.Peer] = struct{}{}
}
}
}
// foldRouteSide folds a route side: when this side is the one that matched, fold its
// groups only on a whole-group change (outputGroups) so siblings of a single moved
// peer stay put; otherwise it is an opposite side and folds fully.
func (r *resolver) foldRouteSide(groups []string, matchedHere bool) {
if matchedHere {
r.foldOutputGroups(groups)
return
}
addAll(r.affectedGroups, groups)
}
// foldOutputGroups folds only the groups that the caller reported as wholly changed
// (outputGroups). Used for a matched object's OWN side, where a peer-seeded or
// link-only group must not pull in its siblings.
func (r *resolver) foldOutputGroups(groups ...[]string) {
for _, gs := range groups {
for _, gID := range gs {
if _, ok := r.outputGroups[gID]; ok {
r.affectedGroups[gID] = struct{}{}
}
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
}
}
func (r *resolver) collectFromNameServers() {
if len(r.linkGroups) == 0 {
if len(r.changedGroupSet) == 0 {
return
}
for _, ns := range r.snap.nsGroups {
if anyInSet(ns.Groups, r.linkGroups) {
// A nameserver group has no opposite side: a peer's DNS config depends only
// on its own membership, so a one-peer move refreshes that peer alone (folded
// elsewhere). Fold the referenced groups only on a whole-group change.
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a linked group -> folding its groups %v (outputGroups only)", ns.ID, ns.Groups)
r.foldOutputGroups(ns.Groups)
if anyInSet(ns.Groups, r.changedGroupSet) {
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
addAll(r.groupSet, ns.Groups)
}
}
}
func (r *resolver) collectFromDNSSettings() {
if len(r.linkGroups) == 0 || r.snap.dnsSettings == nil {
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
return
}
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
if _, ok := r.linkGroups[gID]; ok {
if _, ok := r.changedGroupSet[gID]; ok {
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
r.affectedGroups[gID] = struct{}{}
r.groupSet[gID] = struct{}{}
}
}
}
// collectFromNetworkRouters handles a changed group/peer that BACKS a router (the
// routing peer set moved): the router's own peers refresh and so do the sources of
// the policies reaching its network's resources. Sibling routers on the network are
// independent and are not folded.
func (r *resolver) collectFromNetworkRouters() {
for _, router := range r.networkRouters() {
matchedByGroup := anyInSet(router.PeerGroups, r.linkGroups)
matchedByPeer := router.Peer != "" && len(r.changedPeers) > 0 && isInSet(router.Peer, r.changedPeers)
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding its peerGroups=%v peer=%q (own groups on outputGroups) + sources reaching network resources",
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
// The backing PeerGroups are the matched (own) side: fold them only on a
// whole-group change so a one-peer move does not wake sibling backing peers. The
// opposite side (policy sources reaching the network) is folded below.
r.foldOutputGroups(router.PeerGroups)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.affectedPeers[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
r.peerSet[router.Peer] = struct{}{}
}
r.networkIDs[router.NetworkID] = struct{}{}
}
}
@@ -827,48 +526,42 @@ func (r *resolver) collectFromProxyServices() {
expanded := r.expandChangedPeersWithGroups()
for _, svc := range services {
if svc == nil || !svc.Enabled {
continue // a disabled service proxies nothing; skip existing account data
if svc == nil {
continue
}
proxyPeers := proxyByCluster[svc.ProxyCluster]
if len(proxyPeers) == 0 {
continue
}
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.linkGroups)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
if !matchedByPeer && !matchedByAccessGroup {
continue
}
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets; access groups %v on outputGroups only",
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
for _, pid := range proxyPeers {
r.affectedPeers[pid] = struct{}{}
r.peerSet[pid] = struct{}{}
}
for _, target := range svc.Targets {
if !target.Enabled {
continue // a disabled target forwards nothing
}
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
r.affectedPeers[target.TargetId] = struct{}{}
r.peerSet[target.TargetId] = struct{}{}
}
}
// AccessGroups are the matched (own) side with no opposite to fold: a member's
// proxy access is self-contained, so a one-peer move refreshes that peer alone.
// Fold the groups only on a whole-group change.
r.foldOutputGroups(svc.AccessGroups)
addAll(r.groupSet, svc.AccessGroups)
}
}
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
if len(r.linkGroups) == 0 {
return r.changedPeers
if len(r.changedGroupSet) == 0 {
return r.changedPeerSet
}
ids := r.peerIDsForGroups(r.linkGroups)
ids := r.peerIDsForGroups(r.changedGroupSet)
if len(ids) == 0 {
return r.changedPeers
return r.changedPeerSet
}
merged := make(map[string]struct{}, len(r.changedPeers)+len(ids))
for id := range r.changedPeers {
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
for id := range r.changedPeerSet {
merged[id] = struct{}{}
}
for _, id := range ids {
@@ -877,36 +570,54 @@ func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
return merged
}
// foldRoutersForResources folds the routers serving the networks of the given
// resources (a destination resource is reached through its network's routers). It is
// the resource -> network -> router hop used by foldPolicySide for a destination.
func (r *resolver) foldRoutersForResources(resourceIDs map[string]struct{}) {
// collectResourceRouterBridge crosses between source peers and routing peers, which
// are reachable only via resource -> network -> router, not through the policy's own
// groups: source -> router (targeted resources' networks), then router -> source.
func (r *resolver) collectResourceRouterBridge() {
r.bridgeSourceToRouters()
r.bridgeRoutersToSources()
}
func (r *resolver) bridgeSourceToRouters() {
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
if len(resourceIDs) == 0 {
return
}
r.foldRoutersOnNetworks(r.resourceNetworkIDs(resourceIDs))
}
// ruleDestinationResourceIDs returns the destination resource IDs of a single rule:
// the direct DestinationResource plus the resources of its destination groups.
func (r *resolver) ruleDestinationResourceIDs(rule *types.PolicyRule) map[string]struct{} {
resourceIDs := make(map[string]struct{})
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
resourceIDs[rule.DestinationResource.ID] = struct{}{}
networkIDs := r.resourceNetworkIDs(resourceIDs)
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
setToSlice(resourceIDs), setToSlice(networkIDs))
for id := range networkIDs {
r.networkIDs[id] = struct{}{}
}
r.addGroupResourceIDs(toSet(rule.Destinations), resourceIDs)
return resourceIDs
}
// networkResourceIDs returns the IDs of all resources on the given network.
func (r *resolver) networkResourceIDs(networkID string) map[string]struct{} {
func (r *resolver) bridgeRoutersToSources() {
if len(r.networkIDs) == 0 {
return
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
setToSlice(r.networkIDs))
r.foldRoutersOnNetworks(r.networkIDs)
resourceIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if resource.NetworkID == networkID {
if _, ok := r.networkIDs[resource.NetworkID]; ok {
resourceIDs[resource.ID] = struct{}{}
}
}
return resourceIDs
if len(resourceIDs) == 0 {
return
}
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.groupSet, r.peerSet)
}
}
}
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
@@ -916,9 +627,9 @@ func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.affectedGroups, router.PeerGroups)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.affectedPeers[router.Peer] = struct{}{}
r.peerSet[router.Peer] = struct{}{}
}
}
}
@@ -939,9 +650,6 @@ func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[
}
destGroupSet := make(map[string]struct{})
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
return true
}
@@ -1006,20 +714,44 @@ func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs
}
}
// collectPolicySources folds the source groups/peers of a snapshot policy's enabled
// rules (a disabled rule grants no access).
func collectPolicySources(policy *types.Policy, groups, peers map[string]struct{}) {
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
addAll(groups, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peers[rule.SourceResource.ID] = struct{}{}
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
addAll(groupSet, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
for _, id := range policy.SourcePostureChecks {
if _, ok := ids[id]; ok {
@@ -1044,7 +776,7 @@ func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, cha
}
}
for _, target := range svc.Targets {
if !target.Enabled || target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
continue
}
if _, ok := changedPeers[target.TargetId]; ok {

View File

@@ -10,8 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/types"
)
// policyGroupsAndPeers mirrors the both-sides extraction (RuleGroups + direct peers)
// the resolver folds in for a changed policy, for asserting the pure logic.
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
// direct peers) the resolver folds in, for asserting the pure logic.
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
peerSet := map[string]struct{}{}
for _, p := range policies {
@@ -19,14 +19,7 @@ func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []s
continue
}
groups = append(groups, p.RuleGroups()...)
for _, rule := range p.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
collectPolicyDirectPeers(p, peerSet)
}
for id := range peerSet {
peers = append(peers, id)
@@ -87,6 +80,26 @@ func TestChangeIsEmpty(t *testing.T) {
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
}
func TestPolicyReferencesGroups(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
}
func TestPolicyReferencesDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
}
func TestPolicyReferencesPostureChecks(t *testing.T) {
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
@@ -94,9 +107,24 @@ func TestPolicyReferencesPostureChecks(t *testing.T) {
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
}
func TestCollectPolicyDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
}, {
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
peerSet := map[string]struct{}{}
collectPolicyDirectPeers(policy, peerSet)
assert.Contains(t, peerSet, "p1")
assert.Contains(t, peerSet, "p2")
assert.NotContains(t, peerSet, "r1")
}
func TestCollectPolicySources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Enabled: true,
Sources: []string{"g1"},
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
Destinations: []string{"g2"},

View File

@@ -520,12 +520,7 @@ func collectDeletableGroups(ctx context.Context, transaction store.Store, accoun
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var snap *affectedpeers.Snapshot
// A membership change affects only the peer itself and the opposite side of THIS
// group's policies — not the group's other members, and not the peer's other
// groups. LinkGroups walks only this group (matched, not expanded); OutputPeerIDs
// refreshes the peer without seeding its other group memberships. For an
// intra-group policy the opposite side is the group, so its members still refresh.
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
@@ -591,11 +586,10 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var snap *affectedpeers.Snapshot
// Same as GroupAddPeer: the removed peer and the opposite side of THIS group's
// policies refresh, not the group's other members or the peer's other groups. The
// peer is no longer in the group's index, but LinkGroups still drives the
// opposite-side walk, and OutputPeerIDs refreshes the removed peer itself.
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
change := affectedpeers.Change{
ChangedGroupIDs: []string{groupID},
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
}
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
@@ -606,6 +600,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
// The removed peer is carried in change.RemovedPeersByGroup and folded in
// only when the group is linked, so loading post-removal is correct.
var err error
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err

View File

@@ -217,7 +217,6 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager,
nil,
nil,
nil,
)
proxyService.SetServiceManager(&testServiceManager{store: testStore})

View File

@@ -220,7 +220,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
}
includeServiceUser, err := strconv.ParseBool(serviceUser)
log.WithContext(r.Context()).Tracef("Should include service user: %v", includeServiceUser)
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return

View File

@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {

View File

@@ -39,7 +39,7 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
@@ -114,7 +114,7 @@ type MockAccountManager struct {
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
@@ -345,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, accountID, sessionStartedAt, nmap)
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
@@ -975,9 +975,9 @@ func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId str
}
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
if am.SyncPeerMetaFunc != nil {
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta, realIP)
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
}
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
}

View File

@@ -74,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
//
// Disconnects use MarkPeerDisconnected and require the session to match
// exactly; see PeerStatus.SessionStartedAt for the protocol.
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
@@ -102,6 +102,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
if am.geo != nil && realIP != nil {
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
}
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
return err
}
@@ -188,40 +192,27 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
}
}
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
} else if settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
return nil
}
// resolvePeerLocation looks up the geo location for realIP, returning nil when
// there is nothing to apply: geo disabled, no real IP, the IP is unchanged from
// what the peer already has, or the lookup failed. Geo lookups are skipped on
// same-IP reconnects since they are comparatively expensive. The returned value
// is applied by Peer.UpdateMetaIfNew so the change is persisted by its peer save.
func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *nbpeer.Peer, realIP net.IP) *nbpeer.Location {
if am.geo == nil || realIP == nil {
return nil
// updatePeerLocationIfChanged refreshes the geolocation on a separate
// row update, only when the connection IP actually changed. Geo lookups
// are expensive so we skip same-IP reconnects.
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return
}
location, err := am.geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return nil
return
}
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) && peer.Location.GeoNameID == location.City.GeonameID {
return nil
}
return &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: location.Country.ISOCode,
CityName: location.City.Names.En,
GeoNameID: location.City.GeonameID,
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
@@ -730,7 +721,7 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
// no auth method provided => reject access
return nil, nil, nil, false, status.ErrNoAuthMethodProvided
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
}
upperKey := strings.ToUpper(setupKey)
@@ -989,9 +980,10 @@ func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
var peer *nbpeer.Peer
var ipv6CapabilityChanged bool
var metaDiff nbpeer.MetaDiff
var updated, versionChanged, ipv6CapabilityChanged bool
var err error
var postureChecks []*posture.Checks
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -1019,16 +1011,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return status.NewPeerLoginExpiredError()
}
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
newLocation := am.resolvePeerLocation(ctx, peer, sync.RealIP)
metaDiff = peer.UpdateMetaIfNew(ctx, sync.Meta, newLocation)
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if metaDiff.Updated() {
if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err
}
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
}
return nil
})
@@ -1036,11 +1037,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
@@ -1051,10 +1047,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
if requiresPeerUpdate(ctx, isStatusChanged, sync.UpdateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, metaDiff.VersionChanged(), metaDiff.HostnameChanged()) {
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1063,29 +1058,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return peer, nmap, resPostureChecks, dnsFwdPort, nil
}
func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, versionChanged, hostname bool) bool {
var reason string
switch {
case isStatusChanged:
reason = "status changed"
case updateAccountPeers:
reason = "update account peers"
case ipv6CapabilityChanged:
reason = "ipv6 capability changed"
case metaDiffAffectsPosture:
reason = "meta diff affects posture"
case versionChanged:
reason = "version changed"
case hostname:
reason = "hostname changed"
default:
return false
}
log.WithContext(ctx).Tracef("peer update required: %s", reason)
return true
}
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
// peer's own validated network map is bidirectional for policy and routing
// reachability, so when the peer stays valid and no source-posture gate is in
@@ -1094,8 +1066,8 @@ func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers
// metadata change that flips a posture result removes this peer from others'
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
// back to the resolver.
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaChangeAffectedPosture bool) []string {
if peerNotValid || metaChangeAffectedPosture {
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
if peerNotValid || (metaUpdated && hasPostureChecks) {
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
}
return affectedPeerIDsFromNetworkMap(nmap, peerID)
@@ -1152,7 +1124,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
var peer *nbpeer.Peer
var shouldStorePeer, shouldUpdatePeers bool
var shouldStorePeer bool
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1179,10 +1151,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed {
shouldStorePeer = true
shouldUpdatePeers = true
}
}
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
@@ -1204,15 +1180,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
peer.Meta = login.Meta
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, false, err
}
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, false, err
}
@@ -1222,7 +1190,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
if shouldUpdatePeers {
if isStatusChanged || shouldStorePeer {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
@@ -1318,22 +1286,12 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, nil, false, nil
}
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
if err != nil {
return nil, nil, false, err
}
@@ -1341,16 +1299,32 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, postureChecks, enableSSH, nil
}
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
for _, peerID := range peerGroupIDs {
groupIDsMap[peerID] = struct{}{}
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return false, err
}
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
for _, g := range peerGroups {
peerGroupIDs[g.ID] = struct{}{}
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
}
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
if len(policies) == 0 {
return nil, nil
}
@@ -1362,7 +1336,11 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
continue
}
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
if err != nil {
return nil, err
}
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
}
@@ -1375,19 +1353,29 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
}
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil {
return nil, err
}
for _, sourceGroup := range rule.Sources {
if slices.Contains(peerGroupIDs, sourceGroup) {
return policy.SourcePostureChecks
group, ok := sourceGroups[sourceGroup]
if !ok {
return nil, fmt.Errorf("failed to check peer in policy source group")
}
if slices.Contains(group.Peers, peerID) {
return policy.SourcePostureChecks, nil
}
}
}
return nil
return nil, nil
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO

View File

@@ -1,16 +1,12 @@
package peer
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sort"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -107,15 +103,6 @@ type Location struct {
GeoNameID uint // city level geoname id
}
// equal reports whether two locations match. ConnectionIP is a net.IP slice, so it uses
// IP.Equal, not ==.
func (l Location) equal(other Location) bool {
return l.CountryCode == other.CountryCode &&
l.CityName == other.CityName &&
l.GeoNameID == other.GeoNameID &&
l.ConnectionIP.Equal(other.ConnectionIP)
}
// NetworkAddress is the IP address with network and MAC address of a network interface
type NetworkAddress struct {
NetIP netip.Prefix `gorm:"serializer:json"`
@@ -175,7 +162,49 @@ type PeerSystemMeta struct { //nolint:revive
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
return len(metaDiff(p, other)) == 0
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
})
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
})
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
})
if !equalNetworkAddresses {
return false
}
sort.Slice(p.Files, func(i, j int) bool {
return p.Files[i].Path < p.Files[j].Path
})
sort.Slice(other.Files, func(i, j int) bool {
return other.Files[i].Path < other.Files[j].Path
})
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
})
if !equalFiles {
return false
}
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.KernelVersion == other.KernelVersion &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.OSVersion == other.OSVersion &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion &&
p.SystemSerialNumber == other.SystemSerialNumber &&
p.SystemProductName == other.SystemProductName &&
p.SystemManufacturer == other.SystemManufacturer &&
p.Environment.Cloud == other.Environment.Cloud &&
p.Environment.Platform == other.Environment.Platform &&
p.Flags.isEqual(other.Flags) &&
capabilitiesEqual(p.Capabilities, other.Capabilities)
}
func (p PeerSystemMeta) isEmpty() bool {
@@ -265,173 +294,26 @@ func (p *Peer) Copy() *Peer {
}
}
// UpdateMetaIfNew updates peer's system metadata and connection geo location if
// new information is provided. newLocation is the geo location resolved from the
// peer's current connection IP, or nil when there is nothing to apply (geo
// disabled, no real IP, or the IP is unchanged); the caller owns the expensive
// lookup and the same-IP guard. It returns a MetaDiff describing what changed;
// diff.Updated() reports whether the peer needs to be persisted.
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLocation *Location) MetaDiff {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
if meta.isEmpty() {
return MetaDiff{}
return updated, versionChanged
}
versionChanged = p.Meta.WtVersion != meta.WtVersion
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
effectiveLocation := p.Location
if newLocation != nil {
effectiveLocation = *newLocation
if p.Meta.isEqual(meta) {
return updated, versionChanged
}
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
if diff.Updated() {
p.Meta = meta
}
p.Location = effectiveLocation
if diff.Updated() {
log.WithContext(ctx).Debug(diff.LogSummary())
}
return diff
}
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
// checks read it). Changed lists what moved, for logging and the persistence decision;
// the snapshots let a posture check be replayed against old and new. Everything is derived
// from these fields, so there are no parallel per-field flags to keep in sync.
type MetaDiff struct {
OldMeta PeerSystemMeta
NewMeta PeerSystemMeta
OldLocation Location
NewLocation Location
Changed []string
}
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
func (d *MetaDiff) Updated() bool {
return len(d.Changed) != 0
}
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
func (d *MetaDiff) VersionChanged() bool {
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
}
// HostnameChanged reports whether the peer's hostname changed.
func (d *MetaDiff) HostnameChanged() bool {
return d.OldMeta.Hostname != d.NewMeta.Hostname
}
// LogSummary renders the changed fields as a single human-readable line.
func (d *MetaDiff) LogSummary() string {
return fmt.Sprintf("peer meta updated, %d field(s) changed: %s",
len(d.Changed), strings.Join(d.Changed, ", "))
}
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
}
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
// list, so the log line and the persistence decision can never disagree.
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
add := func(field string, oldVal, newVal any) {
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
add("hostname", oldMeta.Hostname, newMeta.Hostname)
}
if oldMeta.GoOS != newMeta.GoOS {
add("goos", oldMeta.GoOS, newMeta.GoOS)
}
if oldMeta.Kernel != newMeta.Kernel {
add("kernel", oldMeta.Kernel, newMeta.Kernel)
}
if oldMeta.KernelVersion != newMeta.KernelVersion {
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
}
if oldMeta.Core != newMeta.Core {
add("core", oldMeta.Core, newMeta.Core)
}
if oldMeta.Platform != newMeta.Platform {
add("platform", oldMeta.Platform, newMeta.Platform)
}
if oldMeta.OS != newMeta.OS {
add("os", oldMeta.OS, newMeta.OS)
}
if oldMeta.OSVersion != newMeta.OSVersion {
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
}
if oldMeta.WtVersion != newMeta.WtVersion {
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
}
if oldMeta.UIVersion != newMeta.UIVersion {
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
}
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
}
if oldMeta.SystemProductName != newMeta.SystemProductName {
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
}
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
}
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
}
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
}
if !oldMeta.Flags.isEqual(newMeta.Flags) {
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
}
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
if !oldLocation.equal(newLocation) {
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
}
return d
}
// sameMultiset reports whether two slices contain the same elements with the
// same multiplicity, ignoring order. The element type is the comparison key, so
// every field participates in equality.
func sameMultiset[T comparable](a, b []T) bool {
if len(a) != len(b) {
return false
}
counts := make(map[T]int, len(a))
for _, v := range a {
counts[v]++
}
for _, v := range b {
counts[v]--
if counts[v] == 0 {
delete(counts, v)
}
}
return len(counts) == 0
p.Meta = meta
updated = true
return updated, versionChanged
}
// GetLastLogin returns the last login time of the peer.

View File

@@ -1,113 +0,0 @@
package peer
import (
"net/netip"
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
// map 1:1 to a single diff entry. Today the only such field is Environment, which
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
// entry beyond its single struct field. If you teach metaDiff to explode another
// field into N entries, bump this by N-1; if you collapse a field, lower it.
const metaDiffExtraEntries = 1
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
// values and diffs it against the zero value, then asserts metaDiff emits exactly
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
//
// The expected count is derived from the struct via reflection, so adding a field
// to PeerSystemMeta raises the expectation automatically — but the actual diff
// only grows if metaDiff was taught to compare the new field. A mismatch means
// someone changed the struct without updating metaDiff (or this test's
// extra-entry accounting), which is exactly what we want to catch.
func TestMetaDiff_CoversAllFields(t *testing.T) {
var full PeerSystemMeta
exported := populateAll(t, reflect.ValueOf(&full).Elem())
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
diff := metaDiff(PeerSystemMeta{}, full)
require.Len(t, diff, exported+metaDiffExtraEntries,
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
"diff was: %v", diff)
require.False(t, full.isEqual(PeerSystemMeta{}),
"isEqual must report a fully-populated meta as different from the zero value")
}
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
// compare would not change the diff count. This flips each Flags field on its own
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
// fails here.
func TestFlags_isEqualChecksEveryField(t *testing.T) {
typ := reflect.TypeOf(Flags{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
require.Equal(t, reflect.Bool, f.Type.Kind(),
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
var a, b Flags
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
}
}
// populateAll sets every exported field of the struct to a deterministic non-zero
// value, recursing into nested structs and the element type of struct slices so
// that each leaf differs from zero. It returns the number of exported fields on
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
// settable exported fields and is comparable with ==).
func populateAll(t *testing.T, v reflect.Value) int {
t.Helper()
typ := v.Type()
exported := 0
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
if f.PkgPath != "" { // unexported
continue
}
exported++
setNonZero(t, v.Field(i))
}
return exported
}
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
// recursing into nested structs and populating one element of slice fields.
func setNonZero(t *testing.T, field reflect.Value) {
t.Helper()
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
return
}
switch field.Kind() {
case reflect.String:
field.SetString("non-zero")
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Float32, reflect.Float64:
field.SetFloat(7)
case reflect.Struct:
populateAll(t, field)
case reflect.Slice:
s := reflect.MakeSlice(field.Type(), 1, 1)
setNonZero(t, s.Index(0))
field.Set(s)
default:
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
}
}

View File

@@ -49,7 +49,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/geolocation"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@@ -2894,141 +2893,3 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
require.NoError(t, err, "renaming to unique FQDN should succeed")
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
}
// fakeGeo is a configurable geolocation.Geolocation implementation for tests. It
// returns a record built from the configured city geoname id, or an error when set.
type fakeGeo struct {
geoNameID uint
isoCode string
cityName string
err error
}
func (g *fakeGeo) Lookup(net.IP) (*geolocation.Record, error) {
if g.err != nil {
return nil, g.err
}
record := &geolocation.Record{}
record.City.GeonameID = g.geoNameID
record.City.Names.En = g.cityName
record.Country.ISOCode = g.isoCode
return record, nil
}
func (g *fakeGeo) GetAllCountries() ([]geolocation.Country, error) { return nil, nil }
func (g *fakeGeo) GetCitiesByCountry(string) ([]geolocation.City, error) { return nil, nil }
func (g *fakeGeo) Stop() error { return nil }
func TestResolvePeerLocation(t *testing.T) {
realIP := net.ParseIP("203.0.113.10")
tests := []struct {
name string
geo geolocation.Geolocation
peer *nbpeer.Peer
realIP net.IP
want *nbpeer.Location
wantNil bool
}{
{
name: "no geo configured returns nil",
geo: nil,
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
wantNil: true,
},
{
name: "nil real IP returns nil",
geo: &fakeGeo{geoNameID: 100},
peer: &nbpeer.Peer{ID: "p1"},
realIP: nil,
wantNil: true,
},
{
name: "lookup error returns nil",
geo: &fakeGeo{err: fmt.Errorf("lookup boom")},
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
wantNil: true,
},
{
name: "same IP and same geoname returns nil",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: realIP,
GeoNameID: 100,
},
},
realIP: realIP,
wantNil: true,
},
{
name: "same IP but changed geoname returns location",
geo: &fakeGeo{geoNameID: 200, isoCode: "US", cityName: "City B"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: realIP,
GeoNameID: 100,
},
},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City B",
GeoNameID: 200,
},
},
{
name: "different IP returns location",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: net.ParseIP("198.51.100.7"),
GeoNameID: 100,
},
},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City A",
GeoNameID: 100,
},
},
{
name: "no prior location returns location",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City A",
GeoNameID: 100,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
am := &DefaultAccountManager{geo: tt.geo}
got := am.resolvePeerLocation(context.Background(), tt.peer, tt.realIP)
if tt.wantNil {
assert.Nil(t, got, "resolved location should be nil")
return
}
require.NotNil(t, got, "resolved location should not be nil")
assert.True(t, tt.want.ConnectionIP.Equal(got.ConnectionIP), "connection IP should match")
assert.Equal(t, tt.want.CountryCode, got.CountryCode, "country code should match")
assert.Equal(t, tt.want.CityName, got.CityName, "city name should match")
assert.Equal(t, tt.want.GeoNameID, got.GeoNameID, "geoname id should match")
})
}
}

View File

@@ -1,202 +0,0 @@
package posture
import (
"context"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
// diffFrom builds a MetaDiff from the old/new snapshots AffectsPosture replays against.
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
return &nbpeer.MetaDiff{
OldMeta: oldMeta,
NewMeta: newMeta,
OldLocation: oldLoc,
NewLocation: newLoc,
}
}
func checks(def ChecksDefinition) []*Checks {
return []*Checks{{Checks: def}}
}
func TestAffectsPosture_NilDiff(t *testing.T) {
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
})))
}
func TestAffectsPosture_NBVersion(t *testing.T) {
c := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})
tests := []struct {
name string
oldVer, newVer string
want bool
}{
{"both above min, no flip", "1.3.0", "1.4.0", false},
{"both below min, no flip", "1.0.0", "1.1.0", false},
{"crosses up below->above", "1.1.0", "1.3.0", true},
{"crosses down above->below", "1.3.0", "1.1.0", true},
{"unparsable old only -> flip", "garbage", "1.3.0", true},
{"unparsable both -> no flip", "garbage", "junk", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: tt.oldVer},
nbpeer.PeerSystemMeta{WtVersion: tt.newVer},
nbpeer.Location{}, nbpeer.Location{},
)
assert.Equal(t, tt.want, AffectsPosture(context.Background(), diff, c))
})
}
}
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
}})
// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
withinMin := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.15.0-arch2"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), withinMin, c))
// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
crossesDown := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0-arch1"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), crossesDown, c))
}
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
// Only Linux is constrained. An OS outside the switch (freebsd) passes; switching to a
// failing linux kernel flips the verdict pass -> fail.
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
}})
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "freebsd"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
// Process runs at a linux path. Switching GoOS to windows (no WindowsPath configured)
// flips the verdict.
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
}})
files := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", Files: files},
nbpeer.PeerSystemMeta{GoOS: "windows", Files: files},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
// A tracked process stays running while an unrelated file is added: the verdict does
// not move, so posture is not affected.
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
}})
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
{Path: "/usr/bin/foo", ProcessIsRunning: true},
}},
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
{Path: "/usr/bin/foo", ProcessIsRunning: true},
{Path: "/usr/bin/bar", ProcessIsRunning: true},
}},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_GeoLocation(t *testing.T) {
c := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{{CountryCode: "DE"}},
}})
// Moving within allowed countries keeps the verdict; moving out flips it.
stayAllowed := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{CountryCode: "DE", CityName: "Berlin"},
nbpeer.Location{CountryCode: "DE", CityName: "Munich"},
)
assert.False(t, AffectsPosture(context.Background(), stayAllowed, c))
moveOut := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{CountryCode: "DE"},
nbpeer.Location{CountryCode: "FR"},
)
assert.True(t, AffectsPosture(context.Background(), moveOut, c))
}
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
// The check reads the connection IP. Moving out of the allowed range flips the verdict;
// moving within it does not.
_, allowed, _ := net.ParseCIDR("10.0.0.0/8")
c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{netip.MustParsePrefix(allowed.String())},
}})
movesOutOfRange := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
)
assert.True(t, AffectsPosture(context.Background(), movesOutOfRange, c))
staysInRange := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
)
assert.False(t, AffectsPosture(context.Background(), staysInRange, c))
}
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
// Hostname changes but no check reads it: not affected even with checks present.
c := checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
GeoLocationCheck: &GeoLocationCheck{Action: CheckActionAllow, Locations: []Location{{CountryCode: "DE"}}},
})
diff := diffFrom(
nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"},
nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"},
nbpeer.Location{CountryCode: "DE"}, nbpeer.Location{CountryCode: "DE"},
)
assert.False(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_NoChecks(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, nil))
}

View File

@@ -7,8 +7,6 @@ import (
"regexp"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
@@ -53,46 +51,6 @@ type Checks struct {
Checks ChecksDefinition `gorm:"serializer:json"`
}
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
// replays each check against the peer's old and new state and compares verdicts, so a
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
// evaluation error counts.
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
if diff == nil {
return false
}
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
for _, c := range checks {
for _, check := range c.GetChecks() {
if verdictChanged(ctx, check, oldPeer, newPeer) {
return true
}
}
}
return false
}
// verdictChanged replays check against old and new state and reports whether the verdict
// differs. Like callers, it treats an evaluation error as deny: two errors are the same
// verdict (no change), an error on one side only is a flip.
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
oldPass, oldErr := check.Check(ctx, oldPeer)
newPass, newErr := check.Check(ctx, newPeer)
oldVerdict := oldPass && (oldErr == nil)
newVerdict := newPass && (newErr == nil)
changed := oldVerdict != newVerdict
log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
check.Name(), oldVerdict, newVerdict, changed, oldErr, newErr)
return changed
}
// ChecksDefinition contains definition of actual check
type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"`

Some files were not shown because too many files have changed in this diff Show More