mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-14 03:49:54 +00:00
Compare commits
340 Commits
feature/pr
...
ui-refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81dbecb896 | ||
|
|
a8462e3f9b | ||
|
|
c6d5136953 | ||
|
|
b35ca9fde3 | ||
|
|
0853cec437 | ||
|
|
9328558dbb | ||
|
|
13b4bf93b9 | ||
|
|
0d950d46f3 | ||
|
|
4854e5d370 | ||
|
|
cdcdd6a44f | ||
|
|
7648aa7015 | ||
|
|
4b705d472c | ||
|
|
8a4e686098 | ||
|
|
7eaea03bc9 | ||
|
|
eb788cad88 | ||
|
|
439b85c584 | ||
|
|
11281b681b | ||
|
|
6e9f4797e9 | ||
|
|
bbba18fc96 | ||
|
|
d70b3a3fa4 | ||
|
|
dac2ca4088 | ||
|
|
4e1e7f9518 | ||
|
|
9049974f26 | ||
|
|
f8e3ac6d92 | ||
|
|
96d31c3a5a | ||
|
|
bada2b5b78 | ||
|
|
32be58cb24 | ||
|
|
aaa5dbb606 | ||
|
|
4fc125dd38 | ||
|
|
727e6d3004 | ||
|
|
4a79c24792 | ||
|
|
6b2ae1c34c | ||
|
|
43175e0730 | ||
|
|
b598274424 | ||
|
|
ee912e176a | ||
|
|
6e23ed4da7 | ||
|
|
b067544c8a | ||
|
|
10d84b758f | ||
|
|
0e4d0128b6 | ||
|
|
21f1142355 | ||
|
|
9ecc083139 | ||
|
|
efd874efac | ||
|
|
5877880789 | ||
|
|
4427aaa31f | ||
|
|
0ce3fbf5af | ||
|
|
3967864172 | ||
|
|
296d3f124b | ||
|
|
adf1fe1858 | ||
|
|
1108808ab1 | ||
|
|
db371a0263 | ||
|
|
1412b06999 | ||
|
|
ca6f6d88cb | ||
|
|
e298747203 | ||
|
|
cba4a8a63b | ||
|
|
93e068f753 | ||
|
|
94065a8058 | ||
|
|
d7263a6be9 | ||
|
|
64199209cf | ||
|
|
166c6118e2 | ||
|
|
1710868a09 | ||
|
|
7798b7cf14 | ||
|
|
179966b000 | ||
|
|
48265a0143 | ||
|
|
9a76507b14 | ||
|
|
c3a0c1beeb | ||
|
|
45095818bc | ||
|
|
1f74ee9c78 | ||
|
|
562a538e91 | ||
|
|
50b26a21fd | ||
|
|
fefd0da7bf | ||
|
|
d1f3d88f0d | ||
|
|
a71ef1bab0 | ||
|
|
16743c4ce5 | ||
|
|
3c0a9c314b | ||
|
|
072d789463 | ||
|
|
8d05fe07bf | ||
|
|
61da51ed2e | ||
|
|
60c86c63aa | ||
|
|
4cee07bef5 | ||
|
|
5bebecc427 | ||
|
|
3dbd96b172 | ||
|
|
6fe35cae83 | ||
|
|
88bd1f91a8 | ||
|
|
acfd680560 | ||
|
|
0b8aae4566 | ||
|
|
daf9a74d8f | ||
|
|
8af90e40d5 | ||
|
|
3f989f69cb | ||
|
|
53d43980ad | ||
|
|
49df24b18c | ||
|
|
c5611dd766 | ||
|
|
a4ad93008b | ||
|
|
101e04f9fb | ||
|
|
710d5c6182 | ||
|
|
7538a9a133 | ||
|
|
5f7657b95e | ||
|
|
27873866c2 | ||
|
|
18348e1491 | ||
|
|
9569ac2081 | ||
|
|
0b484133b2 | ||
|
|
5df570feb8 | ||
|
|
ed4d823755 | ||
|
|
cedfa2ebf7 | ||
|
|
8b03c96851 | ||
|
|
b830a45333 | ||
|
|
b0d8ac6489 | ||
|
|
558769e671 | ||
|
|
fb6138a3ba | ||
|
|
b111c38b7c | ||
|
|
f54121ebfa | ||
|
|
122d172f33 | ||
|
|
0b19a99693 | ||
|
|
0309f992ad | ||
|
|
1ad2d90d3b | ||
|
|
93a1547871 | ||
|
|
04ab9b5bad | ||
|
|
61431801ea | ||
|
|
02e3cb9987 | ||
|
|
7a78b9df8a | ||
|
|
1416a2e160 | ||
|
|
88db1724bf | ||
|
|
d0d7252c24 | ||
|
|
9dc9e7184e | ||
|
|
1985caf993 | ||
|
|
16570b3223 | ||
|
|
967235e964 | ||
|
|
7d876571da | ||
|
|
e6a624dcee | ||
|
|
bee92f5fcd | ||
|
|
f4914fdfcc | ||
|
|
2cdc6ef1c6 | ||
|
|
3279b705fe | ||
|
|
e94a4cbce5 | ||
|
|
c1db8ab0ab | ||
|
|
2bf945e745 | ||
|
|
4556d52a60 | ||
|
|
51b243bdfa | ||
|
|
e09bc8894d | ||
|
|
55c1f44fb0 | ||
|
|
ac8d417c12 | ||
|
|
dccc0ebe4b | ||
|
|
35498c572a | ||
|
|
cda621bb27 | ||
|
|
d57b30f8d5 | ||
|
|
d82b950718 | ||
|
|
3bd058d425 | ||
|
|
0082f51830 | ||
|
|
e4420b1f96 | ||
|
|
a5635f8825 | ||
|
|
966fbec119 | ||
|
|
f693d268b4 | ||
|
|
09f4109b01 | ||
|
|
ad7d7fa881 | ||
|
|
b84c7618e7 | ||
|
|
ec5da43d73 | ||
|
|
a8ad73d2d9 | ||
|
|
a241112a1d | ||
|
|
e62dff0f66 | ||
|
|
5cecca2c23 | ||
|
|
0e83d2ad94 | ||
|
|
004a305e46 | ||
|
|
c77e5cef85 | ||
|
|
13179081d2 | ||
|
|
2d3c8fc555 | ||
|
|
61aa3a53ed | ||
|
|
80d6df6260 | ||
|
|
53bbc2d551 | ||
|
|
d9f0189b57 | ||
|
|
91e0520f27 | ||
|
|
67a1f3c4fe | ||
|
|
b6d20edfeb | ||
|
|
18d0019332 | ||
|
|
ecee7df5d8 | ||
|
|
1d783c33d9 | ||
|
|
b14feef1d7 | ||
|
|
0935a5675d | ||
|
|
4818599a93 | ||
|
|
f8c107b087 | ||
|
|
d624c2db74 | ||
|
|
513ecd456c | ||
|
|
8f957ff41a | ||
|
|
598fcbd817 | ||
|
|
17a365926d | ||
|
|
577ce6deb5 | ||
|
|
580cfa0dc5 | ||
|
|
8d4f35352f | ||
|
|
85029898a5 | ||
|
|
c3aeb5be15 | ||
|
|
df61f22d96 | ||
|
|
32df29bbd4 | ||
|
|
0a458ead8b | ||
|
|
aab8274b1a | ||
|
|
d3b660afba | ||
|
|
341848b1ae | ||
|
|
414e7815e4 | ||
|
|
ef6b4f7538 | ||
|
|
a7b26e3c0d | ||
|
|
42534b24c5 | ||
|
|
2aea1f7bb5 | ||
|
|
620233a7ac | ||
|
|
1c15e9976b | ||
|
|
f04e2bada8 | ||
|
|
1d88faf66f | ||
|
|
84093af1f0 | ||
|
|
34a4744565 | ||
|
|
b79b62bee4 | ||
|
|
bec4eb326a | ||
|
|
8748f3810d | ||
|
|
1c5254cb31 | ||
|
|
3f8cd29006 | ||
|
|
ca48de549e | ||
|
|
5b71a4f2ad | ||
|
|
741ce8581d | ||
|
|
6b44d65cac | ||
|
|
f84b1df857 | ||
|
|
c24349e4f1 | ||
|
|
7f7bee630f | ||
|
|
4e0eb9f2d4 | ||
|
|
38a367e0cd | ||
|
|
78fb15e327 | ||
|
|
35e58a2796 | ||
|
|
a6278936af | ||
|
|
32f62f3ed8 | ||
|
|
7fae703a27 | ||
|
|
f468f15a30 | ||
|
|
5bdccfe8f4 | ||
|
|
cccb0e9230 | ||
|
|
9d8eb76746 | ||
|
|
1ebb507cbb | ||
|
|
5411fa4350 | ||
|
|
17cae1a75c | ||
|
|
c0b0eeb6ab | ||
|
|
d32721d7fc | ||
|
|
288f8dec08 | ||
|
|
db8c9a0e30 | ||
|
|
505fcc7f7a | ||
|
|
0fe8764707 | ||
|
|
c0e7c61c4b | ||
|
|
e4eedbe18f | ||
|
|
fc1db63fc3 | ||
|
|
d841a6aa07 | ||
|
|
258e7ec038 | ||
|
|
1932b76f5b | ||
|
|
d33b841a33 | ||
|
|
df1935da6d | ||
|
|
eb6be5a2f3 | ||
|
|
209f14fc2f | ||
|
|
2bd56ecf67 | ||
|
|
67988c2407 | ||
|
|
53b2fb8dc1 | ||
|
|
803144e569 | ||
|
|
c0cd88a3d0 | ||
|
|
6c9b821bf0 | ||
|
|
83030dbbd6 | ||
|
|
1c8a6e3798 | ||
|
|
74ea03da9b | ||
|
|
77fdf23a50 | ||
|
|
1f4ed5c8ef | ||
|
|
e1bf362675 | ||
|
|
af40ee52f8 | ||
|
|
4988f2aa68 | ||
|
|
e3efaa5e59 | ||
|
|
100d25a062 | ||
|
|
04b4330393 | ||
|
|
c8e18585c6 | ||
|
|
1931a2c8a8 | ||
|
|
108d43e702 | ||
|
|
842ef0d657 | ||
|
|
439f44c6b4 | ||
|
|
b5a970155b | ||
|
|
686e0d97f2 | ||
|
|
0c287b6f4d | ||
|
|
f7f5946910 | ||
|
|
7a9f5a734f | ||
|
|
1aae067aaa | ||
|
|
28a7eba756 | ||
|
|
8841b950a2 | ||
|
|
0c2702c0d7 | ||
|
|
b43a09a1c7 | ||
|
|
595dfbb6f1 | ||
|
|
7f560df9be | ||
|
|
09052949a2 | ||
|
|
9aef31ff53 | ||
|
|
08f52f4517 | ||
|
|
18e3b5dd32 | ||
|
|
f3f9704c6f | ||
|
|
4c3d4effbd | ||
|
|
3953fee5a4 | ||
|
|
adeaa49cda | ||
|
|
2c5d52a1bf | ||
|
|
70a755fbae | ||
|
|
559da5d5b9 | ||
|
|
614ee11ac7 | ||
|
|
85080afa59 | ||
|
|
a5cc8da054 | ||
|
|
a4fd5a78b4 | ||
|
|
062a183e4e | ||
|
|
a2be41caf8 | ||
|
|
5b70989e3e | ||
|
|
d324a5ff48 | ||
|
|
debb558aa3 | ||
|
|
cce80f8276 | ||
|
|
05ee4e52b8 | ||
|
|
bb2bf673a0 | ||
|
|
91c745e5e8 | ||
|
|
68c38247f1 | ||
|
|
8b8f38de1b | ||
|
|
2b272e74c8 | ||
|
|
e6cbf30415 | ||
|
|
490b60ad0e | ||
|
|
553be144b4 | ||
|
|
c3f9514182 | ||
|
|
a8812d5fb1 | ||
|
|
6f93cf6ac3 | ||
|
|
18909390c2 | ||
|
|
b3eb5f2453 | ||
|
|
dc02542a9e | ||
|
|
0c136fffb9 | ||
|
|
fffb9dd219 | ||
|
|
93275f9052 | ||
|
|
dd9c15072f | ||
|
|
4c743bc03d | ||
|
|
2e61b42e92 | ||
|
|
3f8de2a149 | ||
|
|
bc609c3ae7 | ||
|
|
e3994d0c99 | ||
|
|
ba6e10cef3 | ||
|
|
ce53981b55 | ||
|
|
a69037630b | ||
|
|
df58935cc0 | ||
|
|
a1743dbf9b | ||
|
|
f9771de3f5 | ||
|
|
bfe19fa542 | ||
|
|
d07f25fc49 | ||
|
|
670b0f66ac | ||
|
|
15d73a2edd | ||
|
|
88a2bf582d | ||
|
|
0148d926d5 | ||
|
|
8f16a19b8f | ||
|
|
504dceedf3 |
18
.coderabbit.yaml
Normal file
18
.coderabbit.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: en-US
|
||||
reviews:
|
||||
profile: chill
|
||||
request_changes_workflow: false
|
||||
high_level_summary: true
|
||||
poem: false
|
||||
review_status: true
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
path_filters:
|
||||
- "!**/*.tsx"
|
||||
- "!**/*.ts"
|
||||
- "!**/*.js"
|
||||
- "!**/*.svg"
|
||||
chat:
|
||||
auto_reply: true
|
||||
@@ -6,7 +6,6 @@ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
iptables=1.8.9-2 \
|
||||
libgl1-mesa-dev=22.3.6-1+deb12u1 \
|
||||
xorg-dev=1:7.7+23 \
|
||||
libayatana-appindicator3-dev=0.5.92-1 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& go install -v golang.org/x/tools/gopls@latest
|
||||
|
||||
10
.github/workflows/golang-test-darwin.yml
vendored
10
.github/workflows/golang-test-darwin.yml
vendored
@@ -45,7 +45,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the grep then drops the broken package by path. Without -e,
|
||||
# go list aborts with empty stdout and `go test` falls back to the repo
|
||||
# root, which has no Go files.
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
|
||||
17
.github/workflows/golang-test-linux.yml
vendored
17
.github/workflows/golang-test-linux.yml
vendored
@@ -53,7 +53,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
@@ -145,7 +145,7 @@ jobs:
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
@@ -158,7 +158,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the grep then drops the broken package by path. Without -e,
|
||||
# go list aborts with empty stdout and `go test` falls back to the repo
|
||||
# root, which has no Go files.
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -168,7 +176,6 @@ jobs:
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,client
|
||||
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -229,7 +236,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -e -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
|
||||
9
.github/workflows/golang-test-windows.yml
vendored
9
.github/workflows/golang-test-windows.yml
vendored
@@ -65,8 +65,15 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- name: Generate test script
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the Where-Object pipeline then drops the broken package by
|
||||
# path. Without -e, go list aborts with empty stdout.
|
||||
run: |
|
||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||
$packages = go list -e ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } | Where-Object { $_ -notmatch '/client/ui' }
|
||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||
|
||||
21
.github/workflows/golangci-lint.yml
vendored
21
.github/workflows/golangci-lint.yml
vendored
@@ -22,7 +22,15 @@ jobs:
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
# Non-English UI translations trip codespell on real foreign words
|
||||
# (de: "Sie", "oder", "ist"). Only en/common.json is the source of
|
||||
# truth that should be spell-checked. List each translated locale
|
||||
# dir below and add new ones as languages are added under
|
||||
# client/ui/i18n/locales/. Single-star globs are matched per path
|
||||
# segment by codespell and behave the same across versions; the
|
||||
# recursive "**" form did not take effect with the codespell shipped
|
||||
# by this action.
|
||||
skip: go.mod,go.sum,*/proxy/web/*,*pnpm-lock.yaml,*package-lock.json,*/locales/de/*,*/locales/hu/*
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -54,7 +62,16 @@ jobs:
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Stub Wails frontend bundle
|
||||
# client/ui/main.go has //go:embed all:frontend/dist. The
|
||||
# directory is produced by `pnpm run build` and is gitignored, so
|
||||
# lint-only runs (no frontend toolchain) need a placeholder file
|
||||
# for the embed pattern to match.
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p client/ui/frontend/dist
|
||||
touch client/ui/frontend/dist/.embed-placeholder
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
|
||||
79
.github/workflows/release.yml
vendored
79
.github/workflows/release.yml
vendored
@@ -194,9 +194,9 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
run: goversioninfo -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
@@ -356,8 +356,18 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
|
||||
- name: Set up pnpm
|
||||
uses: pnpm/action-setup@v3
|
||||
with:
|
||||
version: 11
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev gcc-mingw-w64-x86-64
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
@@ -376,10 +386,16 @@ jobs:
|
||||
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the binding generator always matches
|
||||
# the wails runtime the binary links against.
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
run: goversioninfo -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
@@ -447,6 +463,20 @@ jobs:
|
||||
run: go mod tidy
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
- name: Set up pnpm
|
||||
uses: pnpm/action-setup@v3
|
||||
with:
|
||||
version: 11
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the binding generator always matches
|
||||
# the wails runtime the binary links against.
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
@@ -534,23 +564,6 @@ jobs:
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
|
||||
destination: ${{ env.downloadPath }}\mesa3d.7z
|
||||
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
|
||||
|
||||
- name: Move opengl32.dll into dist (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
@@ -573,6 +586,28 @@ jobs:
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Set up Go for wails3 CLI
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the bootstrapper payload always
|
||||
# matches the wails runtime the binary links against.
|
||||
shell: bash
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
|
||||
- name: Stage WebView2 bootstrapper for installers
|
||||
# Both client/installer.nsis and client/netbird.wxs reference
|
||||
# client/MicrosoftEdgeWebview2Setup.exe. wails3 writes it there.
|
||||
# The signing pipeline (netbirdio/sign-pipelines) does the same
|
||||
# step for release builds; this mirrors it for PR sanity testing.
|
||||
shell: bash
|
||||
run: wails3 generate webview2bootstrapper -dir client
|
||||
|
||||
- name: Build NSIS installer
|
||||
shell: pwsh
|
||||
env:
|
||||
|
||||
2
.github/workflows/wasm-build-validation.yml
vendored
2
.github/workflows/wasm-build-validation.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
|
||||
@@ -114,6 +114,16 @@ linters:
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1012"
|
||||
# client/ui/main.go uses //go:embed all:frontend/dist; the
|
||||
# directory is populated by `pnpm build` in the release pipeline
|
||||
# and missing at lint time, so the embed parses to "no matching
|
||||
# files found" — surfaced by golangci-lint's typecheck pre-pass.
|
||||
# Suppress just that one diagnostic; the rest of the package
|
||||
# (services/, tray.go, grpc.go, ...) still gets linted normally.
|
||||
- linters:
|
||||
- typecheck
|
||||
path: client/ui/main\.go
|
||||
text: "pattern all:frontend/dist"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
|
||||
@@ -196,6 +196,7 @@ nfpms:
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
vendor: NetBird
|
||||
id: netbird_deb
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
@@ -210,6 +211,7 @@ nfpms:
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
vendor: NetBird
|
||||
id: netbird_rpm
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
|
||||
before:
|
||||
hooks:
|
||||
# Bindings are gitignored; regenerate before the frontend build so
|
||||
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
|
||||
# build without them).
|
||||
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
|
||||
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
|
||||
|
||||
builds:
|
||||
- id: netbird-ui
|
||||
dir: client/ui
|
||||
@@ -61,6 +70,8 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
vendor: NetBird
|
||||
id: netbird_ui_deb
|
||||
package_name: netbird-ui
|
||||
builds:
|
||||
@@ -70,16 +81,20 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/build/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/assets/netbird.png
|
||||
- src: client/ui/build/linux/netbird.desktop
|
||||
dst: /usr/share/applications/org.wails.netbird.desktop
|
||||
- src: client/ui/build/appicon.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
- libgtk-4-1
|
||||
- libwebkitgtk-6.0-4
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
license: BSD-3-Clause
|
||||
vendor: NetBird
|
||||
id: netbird_ui_rpm
|
||||
package_name: netbird-ui
|
||||
builds:
|
||||
@@ -89,12 +104,14 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/build/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/assets/netbird.png
|
||||
- src: client/ui/build/linux/netbird.desktop
|
||||
dst: /usr/share/applications/org.wails.netbird.desktop
|
||||
- src: client/ui/build/appicon.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
- gtk4
|
||||
- webkitgtk6.0
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
|
||||
before:
|
||||
hooks:
|
||||
# Bindings are gitignored; regenerate before the frontend build so
|
||||
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
|
||||
# build without them).
|
||||
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
|
||||
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
|
||||
|
||||
builds:
|
||||
- id: netbird-ui-darwin
|
||||
dir: client/ui
|
||||
@@ -20,8 +29,6 @@ builds:
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird-ui-darwin
|
||||
|
||||
@@ -291,8 +291,6 @@ go test -exec sudo ./...
|
||||
```
|
||||
> On Windows use a powershell with administrator privileges
|
||||
|
||||
> Non-GTK environments will need the `libayatana-appindicator3-dev` (debian/ubuntu) package installed
|
||||
|
||||
## Checklist before submitting a PR
|
||||
As a critical network service and open-source project, we must enforce a few things before submitting the pull-requests:
|
||||
- Keep functions as simple as possible, with a single purpose
|
||||
|
||||
14
Makefile
14
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
|
||||
.PHONY: lint lint-all lint-install setup-hooks
|
||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
|
||||
# Install golangci-lint locally if needed
|
||||
@@ -25,15 +25,3 @@ setup-hooks:
|
||||
@git config core.hooksPath .githooks
|
||||
@chmod +x .githooks/pre-push
|
||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||
|
||||
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
|
||||
# Runs as a normal user with no sudo and leaves host networking untouched.
|
||||
test-unit:
|
||||
@go test -tags devcert -timeout 10m ./...
|
||||
|
||||
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
|
||||
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
|
||||
# Narrow the run with env vars, e.g.:
|
||||
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||
test-privileged:
|
||||
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...
|
||||
|
||||
@@ -3,14 +3,12 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -87,73 +85,6 @@ var persistenceCmd = &cobra.Command{
|
||||
RunE: setSyncResponsePersistence,
|
||||
}
|
||||
|
||||
var debugConfigCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Example: " netbird debug config",
|
||||
Short: "Dump the effective configuration",
|
||||
Long: "Prints the daemon's resolved configuration (after applying defaults, file, env, CLI input, and MDM policy overrides) as JSON. Includes the list of MDM-managed fields.",
|
||||
RunE: debugConfigDump,
|
||||
}
|
||||
|
||||
// debugConfigDump implements `netbird debug config`. It resolves the
|
||||
// active profile, queries the daemon for the effective configuration
|
||||
// via GetConfig, and prints the resulting GetConfigResponse as JSON
|
||||
// (via protojson with EmitUnpopulated=true so the output is stable
|
||||
// across runs and includes zero-valued fields).
|
||||
//
|
||||
// Useful for verifying MDM enforcement end-to-end: the response's
|
||||
// mDMManagedFields array is the single source of truth for "which
|
||||
// fields is the daemon currently enforcing from the MDM source", and
|
||||
// every config field side-by-side with that list confirms the merge
|
||||
// result. Secrets in the response (e.g. PreSharedKey) are already
|
||||
// redacted by the daemon-side handler.
|
||||
func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Use protojson so well-known fields render correctly; emit defaults so
|
||||
// the operator sees every field even when zero/empty.
|
||||
m := protojson.MarshalOptions{Multiline: true, Indent: " ", EmitUnpopulated: true}
|
||||
out, err := m.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
cmd.Println(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugBundle requests the daemon to create a debug bundle and prints
|
||||
// the resulting local file path and, if uploaded, the uploaded file
|
||||
// key. It uses the package flags (anonymize, system info, log file
|
||||
// count, CLI version, optional upload URL) to configure the bundle
|
||||
// request. Returns an error if the RPC fails or if the daemon reports
|
||||
// an upload failure reason.
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
|
||||
)
|
||||
|
||||
var kubernetesCmd = &cobra.Command{
|
||||
Use: "kubernetes",
|
||||
Short: "Kubernetes cluster commands.",
|
||||
Long: "Kubernetes cluster commands.",
|
||||
}
|
||||
|
||||
var kubernetesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
RunE: kubernetesList,
|
||||
Short: "List Kubernetes clusters.",
|
||||
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
|
||||
}
|
||||
|
||||
var kubernetesWriteKubeconfigCmd = &cobra.Command{
|
||||
Use: "write-kubeconfig",
|
||||
RunE: kubernetesWriteKubeconfig,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Short: "Write kubeconfig for a Kubernetes cluster.",
|
||||
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
|
||||
}
|
||||
|
||||
func init() {
|
||||
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
|
||||
}
|
||||
|
||||
func kubernetesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
cmd.Println("No Kubernetes clusters available.")
|
||||
return nil
|
||||
}
|
||||
cmd.Println("Available Kubernetes clusters:")
|
||||
for _, k := range kcs {
|
||||
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
|
||||
kubeconfigPath, err := resolveKubeconfigPath(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clusterName := args[0]
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
|
||||
}
|
||||
if len(kcs) > 1 {
|
||||
return fmt.Errorf("too many Kubernetes clusters returned")
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kcs[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type kubernetesCluster struct {
|
||||
name string
|
||||
url *url.URL
|
||||
version string
|
||||
}
|
||||
|
||||
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
resolver := net.Resolver{
|
||||
// Required so both DNS records are returned.
|
||||
// https://github.com/golang/go/issues/17093
|
||||
PreferGo: true,
|
||||
}
|
||||
|
||||
kcs := []kubernetesCluster{}
|
||||
attempted := map[string]struct{}{}
|
||||
for _, peer := range peers {
|
||||
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, fqdn := range fqdns {
|
||||
if _, ok := attempted[fqdn]; ok {
|
||||
continue
|
||||
}
|
||||
attempted[fqdn] = struct{}{}
|
||||
comps := strings.Split(fqdn, ".")
|
||||
if len(comps) < 2 {
|
||||
continue
|
||||
}
|
||||
if comps[1] != KubernetesDNSSuffix {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" && nameFilter != comps[0] {
|
||||
continue
|
||||
}
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
|
||||
if err != nil {
|
||||
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
|
||||
continue
|
||||
}
|
||||
kc := kubernetesCluster{
|
||||
name: comps[0],
|
||||
url: clusterURL,
|
||||
version: clusterVersion,
|
||||
}
|
||||
if nameFilter != "" {
|
||||
return []kubernetesCluster{kc}, nil
|
||||
}
|
||||
kcs = append(kcs, kc)
|
||||
}
|
||||
}
|
||||
return kcs, nil
|
||||
}
|
||||
|
||||
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
|
||||
clusterURL, err := url.Parse("https://" + fqdn)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionURL, err := clusterURL.Parse("/version")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionData := map[string]string{}
|
||||
err = json.Unmarshal(b, &versionData)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
version, ok := versionData["gitVersion"]
|
||||
if !ok {
|
||||
return nil, "", errors.New("no version found in response")
|
||||
}
|
||||
return clusterURL, version, nil
|
||||
}
|
||||
|
||||
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
|
||||
if cmd.Flags().Changed("kubeconfig") {
|
||||
path, err := cmd.Flags().GetString("kubeconfig")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
if env := os.Getenv("KUBECONFIG"); env != "" {
|
||||
return env, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".kube", "config"), nil
|
||||
}
|
||||
|
||||
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := yaml.Unmarshal(b, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = map[string]any{
|
||||
"apiVersion": "v1",
|
||||
"kind": "Config",
|
||||
}
|
||||
}
|
||||
|
||||
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
|
||||
"name": kc.name,
|
||||
"cluster": map[string]any{
|
||||
"server": kc.url.String(),
|
||||
"insecure-skip-tls-verify": true,
|
||||
},
|
||||
})
|
||||
cfg["users"] = appendWithName(cfg["users"], map[string]any{
|
||||
"name": "netbird",
|
||||
"user": map[string]any{
|
||||
"token": "none",
|
||||
},
|
||||
})
|
||||
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
|
||||
"name": kc.name,
|
||||
"context": map[string]any{
|
||||
"cluster": kc.name,
|
||||
"user": "netbird",
|
||||
"namespace": "default",
|
||||
},
|
||||
})
|
||||
cfg["current-context"] = kc.name
|
||||
|
||||
out, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendWithName(data any, add map[string]any) any {
|
||||
if data == nil {
|
||||
return []any{add}
|
||||
}
|
||||
v, ok := data.([]any)
|
||||
if !ok {
|
||||
return []any{add}
|
||||
}
|
||||
i := slices.IndexFunc(v, func(item any) bool {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m["name"] == add["name"]
|
||||
})
|
||||
if i == -1 {
|
||||
return append(v, add)
|
||||
}
|
||||
v[i] = add
|
||||
return v
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFingerprintClusters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
//nolint: errcheck
|
||||
w.Write([]byte(`{"gitVersion": "foobar"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, srv.URL, clusterURL.String())
|
||||
require.Equal(t, "foobar", clusterVersion)
|
||||
}
|
||||
|
||||
func TestResolveKubeconfigPath(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
t.Fatalf("could not determine home directory: %v", err)
|
||||
}
|
||||
defaultPath := filepath.Join(home, ".kube", "config")
|
||||
path, err := resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, defaultPath, path)
|
||||
|
||||
flagPath := "flag-path"
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("kubeconfig", "", "")
|
||||
err = cmd.Flags().Set("kubeconfig", flagPath)
|
||||
require.NoError(t, err)
|
||||
path, err = resolveKubeconfigPath(cmd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, flagPath, path)
|
||||
|
||||
envPath := "env-path"
|
||||
t.Setenv("KUBECONFIG", envPath)
|
||||
path, err = resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, envPath, path)
|
||||
}
|
||||
|
||||
func TestWriteKubeconfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
existing string
|
||||
}{
|
||||
{
|
||||
name: "empty file",
|
||||
},
|
||||
{
|
||||
name: "existing content",
|
||||
existing: `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://foobar.com
|
||||
name: foo
|
||||
current-context: test
|
||||
kind: Config
|
||||
users: []
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
kubeconfigPath := filepath.Join(t.TempDir(), "config")
|
||||
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
kc := kubernetesCluster{
|
||||
name: "foo",
|
||||
url: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kc)
|
||||
require.NoError(t, err)
|
||||
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
require.NoError(t, err)
|
||||
expected := `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://example.com
|
||||
name: foo
|
||||
contexts:
|
||||
- context:
|
||||
cluster: foo
|
||||
namespace: default
|
||||
user: netbird
|
||||
name: foo
|
||||
current-context: foo
|
||||
kind: Config
|
||||
users:
|
||||
- name: netbird
|
||||
user:
|
||||
token: none
|
||||
`
|
||||
require.Equal(t, expected, string(b))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -22,11 +22,19 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// extendSessionFlag drives the `netbird login --extend` flow: refresh the
|
||||
// SSO session expiry on the management server without tearing down the
|
||||
// tunnel. Mutually exclusive with setup-key login (a setup-key cannot
|
||||
// refresh an SSO-tracked peer — see auth.errSetupKeyOnSSOExpiredPeer).
|
||||
var extendSessionFlag bool
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
loginCmd.PersistentFlags().BoolVar(&extendSessionFlag, "extend", false,
|
||||
"refresh the SSO session expiry without tearing down the tunnel (requires an active connection)")
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
@@ -61,6 +69,16 @@ var loginCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
if extendSessionFlag {
|
||||
if providedSetupKey != "" {
|
||||
return fmt.Errorf("--extend cannot be combined with a setup key; setup keys can only enrol new peers")
|
||||
}
|
||||
if err := doExtendSession(ctx, cmd); err != nil {
|
||||
return fmt.Errorf("extend session failed: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if util.FindFirstLogPath(logFiles) == "" {
|
||||
if err := doForegroundLogin(ctx, cmd, providedSetupKey, activeProf); err != nil {
|
||||
@@ -150,6 +168,65 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
return nil
|
||||
}
|
||||
|
||||
// doExtendSession drives the daemon's RequestExtendAuthSession /
|
||||
// WaitExtendAuthSession pair. The user is sent through a regular SSO flow
|
||||
// (browser + verification URL) and the resulting JWT is forwarded to the
|
||||
// management server's ExtendAuthSession RPC. The tunnel stays up
|
||||
// throughout — no Down/Up, no network-map resync.
|
||||
func doExtendSession(ctx context.Context, cmd *cobra.Command) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
req := &proto.RequestExtendAuthSessionRequest{}
|
||||
// Pre-fill the IdP login hint from the active profile so the user
|
||||
// doesn't have to retype their email. Best-effort: we still proceed
|
||||
// without a hint if the lookup fails.
|
||||
pm := profilemanager.NewProfileManager()
|
||||
if active, perr := pm.GetActiveProfile(); perr == nil {
|
||||
if profState, sperr := pm.GetProfileState(active.Name); sperr == nil && profState.Email != "" {
|
||||
req.Hint = &profState.Email
|
||||
}
|
||||
}
|
||||
|
||||
startResp, err := client.RequestExtendAuthSession(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("start extend session: %v", err)
|
||||
}
|
||||
|
||||
uri := startResp.GetVerificationURIComplete()
|
||||
if uri == "" {
|
||||
uri = startResp.GetVerificationURI()
|
||||
}
|
||||
openURL(cmd, uri, startResp.GetUserCode(), noBrowser, showQR)
|
||||
|
||||
waitResp, err := client.WaitExtendAuthSession(ctx, &proto.WaitExtendAuthSessionRequest{
|
||||
DeviceCode: startResp.GetDeviceCode(),
|
||||
UserCode: startResp.GetUserCode(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("wait for extend session: %v", err)
|
||||
}
|
||||
|
||||
if ts := waitResp.GetSessionExpiresAt(); ts.IsValid() && !ts.AsTime().IsZero() {
|
||||
deadline := ts.AsTime().Local()
|
||||
cmd.Printf("Session extended. New expiry: %s\n", deadline.Format("2006-01-02 15:04:05 MST"))
|
||||
} else {
|
||||
// Management reported the peer is not eligible (e.g. login
|
||||
// expiration disabled on the account). Surface that fact
|
||||
// instead of pretending the call succeeded.
|
||||
cmd.Println("Session extension call completed, but the management server did not return a new deadline (peer may not be SSO-tracked or login expiration is disabled).")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) (*profilemanager.Profile, error) {
|
||||
// switch profile if provided
|
||||
|
||||
|
||||
@@ -95,9 +95,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// Execute runs the appropriate Cobra command for the CLI.
|
||||
// If the process is the update binary it delegates to updateCmd; otherwise it runs the root command.
|
||||
// It returns any error produced during command execution.
|
||||
// Execute executes the root command.
|
||||
func Execute() error {
|
||||
if isUpdateBinary() {
|
||||
return updateCmd.Execute()
|
||||
@@ -105,16 +103,6 @@ func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
// init initialises package-level defaults and configures the root
|
||||
// Cobra command tree. Sets platform-specific config / log directory
|
||||
// paths (including legacy Wiretrustee fallbacks) and a default daemon
|
||||
// address; registers persistent CLI flags (daemon address,
|
||||
// management / admin URLs, logging, setup key (file and inline,
|
||||
// mutually exclusive), preshared key, hostname, anonymise, config
|
||||
// path); attaches top-level and nested subcommands to the root
|
||||
// command; and registers `up`-specific persistent flags (external IP
|
||||
// maps, custom DNS resolver address, Rosenpass options, auto-connect
|
||||
// disabling, lazy connection).
|
||||
func init() {
|
||||
defaultConfigPathDir = "/etc/netbird/"
|
||||
defaultLogFileDir = "/var/log/netbird/"
|
||||
@@ -180,12 +168,6 @@ func init() {
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
debugCmd.AddCommand(debugConfigCmd)
|
||||
|
||||
// kubernetes commands
|
||||
rootCmd.AddCommand(kubernetesCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesListCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
|
||||
|
||||
// profile commands
|
||||
profileCmd.AddCommand(profileListCmd)
|
||||
|
||||
@@ -1,196 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -27,6 +31,186 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestServiceEnvVars tests environment variable parsing
|
||||
func TestServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
@@ -117,6 +118,11 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var sessionExpiresAt time.Time
|
||||
if ts := resp.GetSessionExpiresAt(); ts.IsValid() {
|
||||
sessionExpiresAt = ts.AsTime().UTC()
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
@@ -127,6 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
IPsFilter: ipsFilterMap,
|
||||
ConnectionTypeFilter: connectionTypeFilter,
|
||||
ProfileName: profName,
|
||||
SessionExpiresAt: sessionExpiresAt,
|
||||
})
|
||||
var statusOutputString string
|
||||
switch {
|
||||
|
||||
@@ -279,10 +279,6 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
||||
// signal stream while holding the engine mutex and only unblocks on
|
||||
// cancellation. Stopping first would deadlock on that mutex.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
@@ -446,8 +442,8 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP doesn't belong to an active peer
|
||||
// — offline roster peers are treated as unknown, same as foreign IPs.
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
@@ -468,7 +464,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
_ = engine.RunHealthProbes(false)
|
||||
_ = engine.RunHealthProbes(context.Background(), false)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
|
||||
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
|
||||
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
|
||||
// holding the engine mutex. When the Start context expires, the rollback path
|
||||
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
|
||||
func TestClientStartTimeoutRollback(t *testing.T) {
|
||||
signalAddr := startBlackholeSignal(t)
|
||||
mgmAddr := startManagement(t, signalAddr)
|
||||
|
||||
wgPort := 0
|
||||
client, err := New(Options{
|
||||
DeviceName: "embed-rollback-test",
|
||||
SetupKey: testSetupKey,
|
||||
ManagementURL: "http://" + mgmAddr,
|
||||
WireguardPort: &wgPort,
|
||||
})
|
||||
require.NoError(t, err, "embed client creation must succeed")
|
||||
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
startErr := make(chan error, 1)
|
||||
go func() {
|
||||
startErr <- client.Start(startCtx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-startErr:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(60 * time.Second):
|
||||
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
|
||||
}
|
||||
}
|
||||
|
||||
// startBlackholeSignal starts a gRPC server without the SignalExchange service
|
||||
// registered. Connections succeed, but the signal stream can never be
|
||||
// established, which keeps Engine.Start parked in WaitStreamConnected.
|
||||
func startBlackholeSignal(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string) string {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: t.TempDir(),
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
permissionsManager := permissions.NewManager(testStore)
|
||||
peersManager := peers.NewManager(testStore, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, testStore, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
require.NoError(t, err)
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
|
||||
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
require.NoError(t, err)
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
require.NoError(t, err)
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && privileged
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && privileged
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android && privileged
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux || !privileged
|
||||
//go:build !linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android && privileged
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
@@ -26,6 +26,64 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
@@ -198,64 +256,6 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
!define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
|
||||
!define INSTALLER_NAME "netbird-installer.exe"
|
||||
!define MAIN_APP_EXE "Netbird"
|
||||
!define ICON "ui\\assets\\netbird.ico"
|
||||
!define ICON "ui\\build\\windows\\icon.ico"
|
||||
!define BANNER "ui\\build\\banner.bmp"
|
||||
!define LICENSE_DATA "..\\LICENSE"
|
||||
|
||||
@@ -280,6 +280,43 @@ CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
SectionEnd
|
||||
|
||||
# Install the Microsoft Edge WebView2 runtime if it isn't already present.
|
||||
# Macro adapted from Wails3's NSIS template (wails_tools.nsh): a registry
|
||||
# probe followed by a silent install of the embedded evergreen bootstrapper.
|
||||
# The MicrosoftEdgeWebview2Setup.exe payload is staged next to this script
|
||||
# by the sign-pipelines build step (`wails3 generate webview2bootstrapper`).
|
||||
!macro nb.webview2runtime
|
||||
SetRegView 64
|
||||
# Per-machine install marker — populated when the runtime ships with
|
||||
# Edge or has been installed by an admin previously.
|
||||
ReadRegStr $0 HKLM "SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
|
||||
${If} $0 != ""
|
||||
Goto webview2_ok
|
||||
${EndIf}
|
||||
# Per-user fallback for HKCU installs.
|
||||
ReadRegStr $0 HKCU "Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
|
||||
${If} $0 != ""
|
||||
Goto webview2_ok
|
||||
${EndIf}
|
||||
|
||||
SetDetailsPrint both
|
||||
DetailPrint "Installing: WebView2 Runtime"
|
||||
SetDetailsPrint listonly
|
||||
|
||||
InitPluginsDir
|
||||
CreateDirectory "$pluginsdir\webview2bootstrapper"
|
||||
SetOutPath "$pluginsdir\webview2bootstrapper"
|
||||
File "MicrosoftEdgeWebview2Setup.exe"
|
||||
ExecWait '"$pluginsdir\webview2bootstrapper\MicrosoftEdgeWebview2Setup.exe" /silent /install'
|
||||
|
||||
SetDetailsPrint both
|
||||
webview2_ok:
|
||||
!macroend
|
||||
|
||||
Section -WebView2
|
||||
!insertmacro nb.webview2runtime
|
||||
SectionEnd
|
||||
|
||||
Section -Post
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install'
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start'
|
||||
@@ -326,9 +363,9 @@ DetailPrint "Deleting application files..."
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
!if ${ARCH} == "amd64"
|
||||
# Legacy: pre-Wails installs shipped opengl32.dll (Mesa3D for Fyne); remove
|
||||
# any leftover copy on uninstall so old upgrades don't leave it behind.
|
||||
Delete "$INSTDIR\opengl32.dll"
|
||||
!endif
|
||||
DetailPrint "Removing application directory..."
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +22,25 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// peerLoginExpiredMsg is the exact phrase the management server returns
|
||||
// when a previously SSO-enrolled peer's login has expired. Sourced from
|
||||
// shared/management/status/error.go (NewPeerLoginExpiredError). Matched
|
||||
// by substring so a future server-side rewording that keeps the phrase
|
||||
// still triggers the friendly fallback in Login().
|
||||
const peerLoginExpiredMsg = "peer login has expired"
|
||||
|
||||
// errSetupKeyOnSSOExpiredPeer replaces the raw management error when the
|
||||
// user runs `netbird login -k <setup-key>` against a peer that was
|
||||
// originally enrolled via SSO. Wrapped in a PermissionDenied gRPC status
|
||||
// so callers' existing isPermissionDenied / isAuthError checks still
|
||||
// classify it correctly (early-exit from retry backoff, StatusNeedsLogin
|
||||
// in the server state machine).
|
||||
var errSetupKeyOnSSOExpiredPeer = status.Error(
|
||||
codes.PermissionDenied,
|
||||
"this peer was originally enrolled via SSO and its session has expired. "+
|
||||
"Setup keys can only enrol new peers — run `netbird up` (interactive SSO) to re-login.",
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
@@ -184,6 +204,15 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
// The peer pub-key is already on file with the management
|
||||
// server (originally enrolled via SSO) and the session has
|
||||
// expired. The setup-key path can only enrol new peers, so
|
||||
// retrying with -k will keep failing. Replace the raw mgm
|
||||
// message with an actionable hint that tells the user to
|
||||
// re-authenticate via SSO instead.
|
||||
if setupKey != "" && jwtToken == "" && isPeerLoginExpired(err) {
|
||||
err = errSetupKeyOnSSOExpiredPeer
|
||||
}
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
@@ -474,3 +503,16 @@ func isLoginNeeded(err error) bool {
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
|
||||
// isPeerLoginExpired reports whether err is the management server's
|
||||
// "peer login has expired" PermissionDenied response. Used by Login to
|
||||
// detect the case where the caller passed a setup-key but the peer is
|
||||
// actually an SSO-enrolled record whose session needs refreshing — the
|
||||
// setup-key path cannot help there.
|
||||
func isPeerLoginExpired(err error) bool {
|
||||
if !isPermissionDenied(err) {
|
||||
return false
|
||||
}
|
||||
s, _ := status.FromError(err)
|
||||
return strings.Contains(s.Message(), peerLoginExpiredMsg)
|
||||
}
|
||||
|
||||
80
client/internal/auth/auth_test.go
Normal file
80
client/internal/auth/auth_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestIsPeerLoginExpired(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "plain error (not a gRPC status)",
|
||||
err: errors.New("network read: connection reset"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "PermissionDenied with different message",
|
||||
err: status.Error(codes.PermissionDenied, "user is blocked"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated with the expected phrase",
|
||||
// Wrong status code — must still return false.
|
||||
err: status.Error(codes.Unauthenticated, "peer login has expired, please log in once more"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exact server message",
|
||||
err: status.Error(codes.PermissionDenied, "peer login has expired, please log in once more"),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "phrase as substring",
|
||||
// Future-proofing: if mgm reworords but keeps the phrase,
|
||||
// the friendly fallback must still kick in.
|
||||
err: status.Error(codes.PermissionDenied, "session refused: peer login has expired (account=foo)"),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := isPeerLoginExpired(tc.err); got != tc.want {
|
||||
t.Fatalf("isPeerLoginExpired(%v) = %v, want %v", tc.err, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrSetupKeyOnSSOExpiredPeer(t *testing.T) {
|
||||
// Sentinel must surface as PermissionDenied so the upstream
|
||||
// isPermissionDenied / isAuthError checks classify it correctly
|
||||
// (short-circuit retry backoff, set StatusNeedsLogin).
|
||||
if !isPermissionDenied(errSetupKeyOnSSOExpiredPeer) {
|
||||
t.Fatalf("errSetupKeyOnSSOExpiredPeer must be a PermissionDenied gRPC error")
|
||||
}
|
||||
|
||||
// Message must actually mention SSO and `netbird up` so it is
|
||||
// actionable for the end user. Loose substring checks keep the
|
||||
// test resilient to copy edits.
|
||||
s, _ := status.FromError(errSetupKeyOnSSOExpiredPeer)
|
||||
msg := strings.ToLower(s.Message())
|
||||
for _, want := range []string{"sso", "netbird up"} {
|
||||
if !strings.Contains(msg, want) {
|
||||
t.Errorf("sentinel message should contain %q, got %q", want, s.Message())
|
||||
}
|
||||
}
|
||||
}
|
||||
89
client/internal/auth/pending_flow.go
Normal file
89
client/internal/auth/pending_flow.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PendingFlow stores an in-progress OAuth flow between the RPC that
|
||||
// initiates it (returns the verification URI to the UI) and the RPC
|
||||
// that waits for the user to complete it. The flow handle, the
|
||||
// device-code info, and the absolute expiry are kept together so the
|
||||
// waiting RPC can validate the device code and reuse the same flow.
|
||||
//
|
||||
// PendingFlow is safe for concurrent use; callers must not access the
|
||||
// stored fields directly.
|
||||
type PendingFlow struct {
|
||||
mu sync.Mutex
|
||||
flow OAuthFlow
|
||||
info AuthFlowInfo
|
||||
expiresAt time.Time
|
||||
waitCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewPendingFlow returns an empty PendingFlow ready to be populated by Set.
|
||||
func NewPendingFlow() *PendingFlow {
|
||||
return &PendingFlow{}
|
||||
}
|
||||
|
||||
// Set stores the flow and its authorization info, computing the absolute
|
||||
// expiry from info.ExpiresIn (seconds, as returned by the IdP).
|
||||
func (p *PendingFlow) Set(flow OAuthFlow, info AuthFlowInfo) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.flow = flow
|
||||
p.info = info
|
||||
p.expiresAt = time.Now().Add(time.Duration(info.ExpiresIn) * time.Second)
|
||||
}
|
||||
|
||||
// Get returns the stored flow, info, and whether a flow is currently
|
||||
// pending. Returns (nil, zero, false) after Clear or before Set.
|
||||
func (p *PendingFlow) Get() (OAuthFlow, AuthFlowInfo, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.flow == nil {
|
||||
return nil, AuthFlowInfo{}, false
|
||||
}
|
||||
return p.flow, p.info, true
|
||||
}
|
||||
|
||||
// ExpiresAt returns the absolute expiry of the pending flow. Returns
|
||||
// the zero time when no flow is pending.
|
||||
func (p *PendingFlow) ExpiresAt() time.Time {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.expiresAt
|
||||
}
|
||||
|
||||
// SetWaitCancel records the cancel function for the goroutine currently
|
||||
// blocked in WaitToken so a new RequestAuth can preempt it.
|
||||
func (p *PendingFlow) SetWaitCancel(cancel context.CancelFunc) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.waitCancel = cancel
|
||||
}
|
||||
|
||||
// CancelWait invokes and clears the stored wait-cancel, if any. Safe to
|
||||
// call when no wait is in progress.
|
||||
func (p *PendingFlow) CancelWait() {
|
||||
p.mu.Lock()
|
||||
cancel := p.waitCancel
|
||||
p.waitCancel = nil
|
||||
p.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Clear resets the pending flow to empty. Any stored wait-cancel is
|
||||
// dropped without being invoked — call CancelWait first if the waiting
|
||||
// goroutine must be stopped.
|
||||
func (p *PendingFlow) Clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.flow = nil
|
||||
p.info = AuthFlowInfo{}
|
||||
p.expiresAt = time.Time{}
|
||||
p.waitCancel = nil
|
||||
}
|
||||
82
client/internal/auth/sessionwatch/event.go
Normal file
82
client/internal/auth/sessionwatch/event.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// internal event kinds are no longer exposed: the watcher drives the Sink
|
||||
// directly (NotifyStateChange on deadline change/clear, PublishEvent at
|
||||
// each warning lead). Tests use a mock Sink to observe what the watcher
|
||||
// emits.
|
||||
|
||||
// Metadata keys attached by the daemon to session-warning SystemEvents.
|
||||
// The UI tray reads these to build a locale-aware notification without
|
||||
// relying on the daemon's locale-less UserMessage string, and to
|
||||
// disambiguate the T-WarningLead notification from the T-FinalWarningLead
|
||||
// fallback that auto-opens the SessionAboutToExpire dialog.
|
||||
const (
|
||||
// MetaSessionWarning is set to "true" on both warning events (T-10 and
|
||||
// T-2) so the UI can detect a session-warning SystemEvent without
|
||||
// matching on the message text. Use MetaSessionFinal to distinguish
|
||||
// the two.
|
||||
MetaSessionWarning = "session_warning"
|
||||
// MetaSessionFinal is set to "true" on the T-FinalWarningLead event
|
||||
// only. Consumers that need to auto-open the SessionAboutToExpire
|
||||
// dialog gate on this; T-WarningLead events leave the field unset.
|
||||
MetaSessionFinal = "session_final_warning"
|
||||
// MetaSessionExpiresAt carries the absolute UTC deadline encoded with
|
||||
// FormatExpiresAt; consumers must decode with ParseExpiresAt so a
|
||||
// future format change stays a single edit.
|
||||
MetaSessionExpiresAt = "session_expires_at"
|
||||
// MetaSessionLeadMinutes carries the lead in whole minutes (WarningLead
|
||||
// for the T-10 event, FinalWarningLead for the T-2 event) so the UI
|
||||
// can show "expires in ~N minutes" without hardcoding either constant.
|
||||
MetaSessionLeadMinutes = "lead_minutes"
|
||||
// MetaSessionDeadlineRejected is attached to the ERROR/AUTHENTICATION
|
||||
// SystemEvent the daemon emits when it discards a deadline from the
|
||||
// management server (pre-epoch, too far in the future, or past the
|
||||
// clock-skew tolerance). The value is the rejection reason string.
|
||||
// userMessage is left empty; the UI detects the event via this key
|
||||
// and builds a localized notification — same pattern as the session
|
||||
// warnings above.
|
||||
MetaSessionDeadlineRejected = "session_deadline_rejected"
|
||||
)
|
||||
|
||||
// expiresAtLayout is the wire format used for MetaSessionExpiresAt.
|
||||
// Producer and consumers both go through FormatExpiresAt/ParseExpiresAt
|
||||
// so this layout stays a single source of truth.
|
||||
const expiresAtLayout = time.RFC3339
|
||||
|
||||
// FormatExpiresAt encodes a deadline for MetaSessionExpiresAt. Always
|
||||
// emits UTC so a consumer in another timezone reads the same wall-clock
|
||||
// deadline.
|
||||
func FormatExpiresAt(t time.Time) string {
|
||||
return t.UTC().Format(expiresAtLayout)
|
||||
}
|
||||
|
||||
// ParseExpiresAt decodes the MetaSessionExpiresAt value back to a UTC
|
||||
// time. Returns an error when the field is empty or malformed; the
|
||||
// caller decides whether to fall back (zero value) or propagate.
|
||||
func ParseExpiresAt(s string) (time.Time, error) {
|
||||
t, err := time.Parse(expiresAtLayout, s)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return t.UTC(), nil
|
||||
}
|
||||
|
||||
// FormatLeadMinutes encodes a lead duration for MetaSessionLeadMinutes
|
||||
// as the integer count of whole minutes. Sub-minute residuals are
|
||||
// truncated — the field is informational ("expires in ~N minutes") and
|
||||
// fractional minutes don't change what the UI displays.
|
||||
func FormatLeadMinutes(d time.Duration) string {
|
||||
return strconv.Itoa(int(d / time.Minute))
|
||||
}
|
||||
|
||||
// ParseLeadMinutes decodes a MetaSessionLeadMinutes value. Returns 0
|
||||
// and the parse error for malformed input; consumers that prefer a
|
||||
// silent fallback can simply ignore the error.
|
||||
func ParseLeadMinutes(s string) (int, error) {
|
||||
return strconv.Atoi(s)
|
||||
}
|
||||
387
client/internal/auth/sessionwatch/watcher.go
Normal file
387
client/internal/auth/sessionwatch/watcher.go
Normal file
@@ -0,0 +1,387 @@
|
||||
// Package sessionwatch tracks the SSO session expiry deadline that the
|
||||
// management server publishes via LoginResponse / SyncResponse and fires
|
||||
// two warning events at fixed lead times before expiry: an interactive
|
||||
// T-WarningLead notification and a dismiss-gated T-FinalWarningLead
|
||||
// fallback dialog.
|
||||
//
|
||||
// The watcher is idempotent: Update may be called as often as the network
|
||||
// map snapshots arrive. Repeating the same deadline is a no-op; a new
|
||||
// deadline reschedules the timers and arms a fresh warning cycle.
|
||||
//
|
||||
// Warning firing is edge-detected. Each unique deadline value fires each
|
||||
// warning callback at most once.
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
// Skew tolerates a small clock difference between the management
|
||||
// server and this peer before treating a deadline as "in the past".
|
||||
// Slightly above typical NTP drift; tight enough that the UI doesn't
|
||||
// paint a stale expiry as if it were valid.
|
||||
Skew = 30 * time.Second
|
||||
|
||||
// maxDeadlineHorizon caps how far in the future an accepted deadline
|
||||
// can sit. A timestamp beyond this is almost certainly a protocol
|
||||
// glitch, and silently arming a 100-year timer would hide the bug.
|
||||
maxDeadlineHorizon = 10 * 365 * 24 * time.Hour
|
||||
|
||||
// WarningLead is how far before expiry the first (interactive)
|
||||
// warning fires. Drives the T-10 OS notification with
|
||||
// Extend/Dismiss actions.
|
||||
WarningLead = 10 * time.Minute
|
||||
|
||||
// FinalWarningLead is how far before expiry the fallback final
|
||||
// warning fires. Drives the auto-opened SessionAboutToExpire dialog,
|
||||
// but only when the user has not dismissed the T-WarningLead warning
|
||||
// for the same deadline. Must be strictly less than WarningLead.
|
||||
FinalWarningLead = 2 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDeadlineBeforeEpoch is returned by Update when the supplied
|
||||
// deadline pre-dates 1970-01-01.
|
||||
ErrDeadlineBeforeEpoch = errors.New("session deadline before unix epoch")
|
||||
|
||||
// ErrDeadlineTooFarFuture is returned by Update when the supplied
|
||||
// deadline is more than maxDeadlineHorizon in the future.
|
||||
ErrDeadlineTooFarFuture = errors.New("session deadline too far in the future")
|
||||
|
||||
// ErrDeadlineInPast is returned by Update when the supplied deadline
|
||||
// is more than Skew in the past.
|
||||
ErrDeadlineInPast = errors.New("session deadline in the past")
|
||||
)
|
||||
|
||||
// StatusRecorder is the side-effect surface the watcher drives on every
|
||||
// state transition. Production wires this to peer.Status (SetSessionExpiresAt
|
||||
// for deadline change/clear, PublishEvent for the two warnings); tests pass
|
||||
// a fake recorder so the same surface is observable without an engine.
|
||||
//
|
||||
// The watcher is the single owner of the deadline propagated to the
|
||||
// recorder: every set, clear, sanity-check rejection and Close routes the
|
||||
// value through SetSessionExpiresAt, so the SubscribeStatus snapshot the UI
|
||||
// reads can never drift from the watcher's timer state. (SetSessionExpiresAt
|
||||
// fans out its own state-change notification, so no separate notify is
|
||||
// needed.) The recorder is server-scoped and outlives this engine-scoped
|
||||
// watcher — without the Close-time clear a teardown (Down, or the Down+Up of
|
||||
// a profile switch) would leave the next session showing the previous one's
|
||||
// stale "expires in" value.
|
||||
//
|
||||
// PublishEvent's signature mirrors peer.Status.PublishEvent: the watcher
|
||||
// composes the metadata internally so the wire format (MetaSession*) is
|
||||
// owned by sessionwatch, not the caller.
|
||||
type StatusRecorder interface {
|
||||
SetSessionExpiresAt(deadline time.Time)
|
||||
PublishEvent(
|
||||
severity cProto.SystemEvent_Severity,
|
||||
category cProto.SystemEvent_Category,
|
||||
message string,
|
||||
userMessage string,
|
||||
metadata map[string]string,
|
||||
)
|
||||
}
|
||||
|
||||
// Watcher observes the latest session deadline and fires two warnings
|
||||
// before it expires: the interactive T-WarningLead notification, and the
|
||||
// fallback T-FinalWarningLead dialog (suppressed when the user dismissed
|
||||
// the first one for the same deadline). Safe for concurrent use.
|
||||
type Watcher struct {
|
||||
lead time.Duration
|
||||
finalLead time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
current time.Time
|
||||
timer *time.Timer
|
||||
finalTimer *time.Timer
|
||||
firedAt time.Time // deadline value the T-WarningLead callback last fired against
|
||||
finalFiredAt time.Time // deadline value the T-FinalWarningLead callback last fired against
|
||||
dismissedAt time.Time // deadline value the user dismissed via Dismiss(); gates fireFinal
|
||||
closed bool
|
||||
recorder StatusRecorder
|
||||
}
|
||||
|
||||
// New returns a watcher with the package defaults WarningLead and
|
||||
// FinalWarningLead. Pass nil for recorder to silence side effects (handy
|
||||
// in unit tests that exercise sanity checks without observing the publish
|
||||
// path).
|
||||
func New(recorder StatusRecorder) *Watcher {
|
||||
return NewWithLeads(WarningLead, FinalWarningLead, recorder)
|
||||
}
|
||||
|
||||
// NewWithLeads returns a watcher with custom lead times. Useful for tests.
|
||||
// final must be strictly less than lead; otherwise both timers fire in the
|
||||
// wrong order or simultaneously and the UI flow breaks. A zero final lead
|
||||
// disables the final-warning timer entirely (see armTimerLocked) so a
|
||||
// millisecond-scale deadline doesn't flush both timers in one tick.
|
||||
func NewWithLeads(lead, final time.Duration, recorder StatusRecorder) *Watcher {
|
||||
return &Watcher{
|
||||
lead: lead,
|
||||
finalLead: final,
|
||||
recorder: recorder,
|
||||
}
|
||||
}
|
||||
|
||||
// Update sets the latest deadline. Pass the zero time to clear (e.g. when
|
||||
// a Sync push from the server omits the field because login expiration
|
||||
// was disabled).
|
||||
//
|
||||
// Same-value updates are no-ops. A different non-zero value cancels any
|
||||
// pending timer, resets the "already fired" guard, and arms a new one.
|
||||
//
|
||||
// Returns one of the sentinel Err* values when the deadline fails the
|
||||
// sanity checks (pre-epoch, far future, or in the past beyond Skew).
|
||||
// In every error case the watcher first clears its state so it stays
|
||||
// consistent with what the caller will push into its other sinks (e.g.
|
||||
// applySessionDeadline forces a zero deadline into the status recorder
|
||||
// after a non-nil error).
|
||||
func (w *Watcher) Update(deadline time.Time) error {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
if deadline.IsZero() {
|
||||
w.clearLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
switch {
|
||||
case deadline.Before(time.Unix(0, 0)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v", ErrDeadlineBeforeEpoch, deadline)
|
||||
case deadline.After(now.Add(maxDeadlineHorizon)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v", ErrDeadlineTooFarFuture, deadline)
|
||||
case deadline.Before(now.Add(-Skew)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v (now=%v)", ErrDeadlineInPast, deadline, now)
|
||||
}
|
||||
|
||||
if deadline.Equal(w.current) {
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
w.stopTimerLocked()
|
||||
w.current = deadline
|
||||
// Reset every per-deadline guard so a refreshed deadline arms a fresh
|
||||
// warning cycle: both edge triggers and the user Dismiss decision
|
||||
// (the user agreed to the old deadline expiring; a new deadline
|
||||
// restarts the contract).
|
||||
w.firedAt = time.Time{}
|
||||
w.finalFiredAt = time.Time{}
|
||||
w.dismissedAt = time.Time{}
|
||||
|
||||
w.armTimerLocked(deadline)
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder != nil {
|
||||
recorder.SetSessionExpiresAt(deadline)
|
||||
}
|
||||
log.Infof("auth session deadline set to: %s (in %s)", deadline.Format(time.RFC3339), time.Until(deadline).Round(time.Second))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deadline returns the most recently observed deadline. Zero when no
|
||||
// deadline is currently tracked.
|
||||
func (w *Watcher) Deadline() time.Time {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.current
|
||||
}
|
||||
|
||||
// Dismiss records the user's "Dismiss" action against the current deadline
|
||||
// and suppresses the upcoming final-warning callback for that deadline.
|
||||
// Idempotent: repeated calls are no-ops. A subsequent Update with a fresh
|
||||
// deadline resets the dismissal so the final-warning cycle re-arms.
|
||||
//
|
||||
// No-op when the watcher holds no deadline or has been closed.
|
||||
func (w *Watcher) Dismiss() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.closed || w.current.IsZero() {
|
||||
return
|
||||
}
|
||||
if w.dismissedAt.Equal(w.current) {
|
||||
return
|
||||
}
|
||||
w.dismissedAt = w.current
|
||||
// Cancel the armed final-warning timer eagerly. fireFinal would also
|
||||
// gate on dismissedAt, but stopping the timer avoids a wakeup with
|
||||
// nothing to do and makes the intent visible.
|
||||
if w.finalTimer != nil {
|
||||
w.finalTimer.Stop()
|
||||
w.finalTimer = nil
|
||||
}
|
||||
log.Infof("auth session final-warning dismissed for deadline %s", w.current.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// Close stops any pending timer and drops the deadline on the status
|
||||
// recorder. Update calls after Close are ignored. Clearing the recorder
|
||||
// here is what keeps a teardown (Down, or the Down+Up of a profile switch)
|
||||
// from leaving the next session showing this one's stale "expires in"
|
||||
// value — the recorder is server-scoped and outlives this engine-scoped
|
||||
// watcher, so nothing else drops the anchor on teardown.
|
||||
func (w *Watcher) Close() {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.closed = true
|
||||
w.stopTimerLocked()
|
||||
hadDeadline := !w.current.IsZero()
|
||||
w.current = time.Time{}
|
||||
w.firedAt = time.Time{}
|
||||
w.finalFiredAt = time.Time{}
|
||||
w.dismissedAt = time.Time{}
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder != nil && hadDeadline {
|
||||
recorder.SetSessionExpiresAt(time.Time{})
|
||||
}
|
||||
}
|
||||
|
||||
// clearLocked drops the tracked deadline and notifies the recorder so
|
||||
// downstream consumers (SubscribeStatus stream, UI) drop their anchor.
|
||||
// The caller must hold w.mu; this helper releases it before invoking
|
||||
// the recorder.
|
||||
func (w *Watcher) clearLocked() {
|
||||
if w.current.IsZero() {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.stopTimerLocked()
|
||||
w.current = time.Time{}
|
||||
w.firedAt = time.Time{}
|
||||
w.finalFiredAt = time.Time{}
|
||||
w.dismissedAt = time.Time{}
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder != nil {
|
||||
recorder.SetSessionExpiresAt(time.Time{})
|
||||
}
|
||||
log.Infof("auth session deadline cleared")
|
||||
}
|
||||
|
||||
func (w *Watcher) stopTimerLocked() {
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
w.timer = nil
|
||||
}
|
||||
if w.finalTimer != nil {
|
||||
w.finalTimer.Stop()
|
||||
w.finalTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) armTimerLocked(deadline time.Time) {
|
||||
w.timer = armOneShotLocked(deadline.Add(-w.lead), func() { w.fire(deadline) })
|
||||
// finalLead <= 0 disables the final-warning timer entirely. Used by
|
||||
// tests that predate the final-warning fallback so a millisecond-scale
|
||||
// deadline does not flush both timers at once.
|
||||
if w.finalLead > 0 {
|
||||
w.finalTimer = armOneShotLocked(deadline.Add(-w.finalLead), func() { w.fireFinal(deadline) })
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) fire(armedFor time.Time) {
|
||||
w.mu.Lock()
|
||||
if w.closed || !w.current.Equal(armedFor) {
|
||||
// Deadline moved while we were waiting (e.g. a successful extend).
|
||||
// The reschedule path armed a fresh timer; this one is stale.
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if !w.firedAt.IsZero() && w.firedAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.firedAt = armedFor
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder == nil {
|
||||
return
|
||||
}
|
||||
log.Infof("auth session expiry soon warning fired")
|
||||
publishWarning(recorder, armedFor, false)
|
||||
}
|
||||
|
||||
// fireFinal mirrors fire for the T-FinalWarningLead timer with an extra
|
||||
// dismiss-gate: if the user dismissed the T-WarningLead notification for
|
||||
// this deadline, the final warning is suppressed entirely.
|
||||
func (w *Watcher) fireFinal(armedFor time.Time) {
|
||||
w.mu.Lock()
|
||||
if w.closed || !w.current.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if !w.finalFiredAt.IsZero() && w.finalFiredAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if w.dismissedAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
log.Infof("auth session final-warning skipped (dismissed by user)")
|
||||
return
|
||||
}
|
||||
w.finalFiredAt = armedFor
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder == nil {
|
||||
return
|
||||
}
|
||||
log.Infof("auth session final-warning fired")
|
||||
publishWarning(recorder, armedFor, true)
|
||||
}
|
||||
|
||||
// armOneShotLocked schedules cb at fireAt. When fireAt is already in the
|
||||
// past it dispatches on the next scheduler tick so a state-change recorder
|
||||
// notification (invoked after w.mu is released) lands first. Caller must
|
||||
// hold w.mu.
|
||||
func armOneShotLocked(fireAt time.Time, cb func()) *time.Timer {
|
||||
delay := time.Until(fireAt)
|
||||
if delay <= 0 {
|
||||
return time.AfterFunc(0, cb)
|
||||
}
|
||||
return time.AfterFunc(delay, cb)
|
||||
}
|
||||
|
||||
// publishWarning composes the SystemEvent for a watcher-fired warning and
|
||||
// pushes it through the recorder. Severity is CRITICAL on both — bypassing
|
||||
// the user's Notifications toggle is deliberate: missing the warning
|
||||
// window forces the post-mortem SessionExpired flow (tunnel torn down,
|
||||
// lock icon, manual re-login), which is the UX we are trying to avoid.
|
||||
func publishWarning(recorder StatusRecorder, deadline time.Time, final bool) {
|
||||
lead := WarningLead
|
||||
message := "session expiry warning"
|
||||
meta := map[string]string{
|
||||
MetaSessionWarning: "true",
|
||||
MetaSessionExpiresAt: FormatExpiresAt(deadline),
|
||||
}
|
||||
if final {
|
||||
lead = FinalWarningLead
|
||||
message = "session expiry final warning"
|
||||
meta[MetaSessionFinal] = "true"
|
||||
}
|
||||
meta[MetaSessionLeadMinutes] = FormatLeadMinutes(lead)
|
||||
|
||||
recorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_AUTHENTICATION,
|
||||
message,
|
||||
"",
|
||||
meta,
|
||||
)
|
||||
}
|
||||
519
client/internal/auth/sessionwatch/watcher_test.go
Normal file
519
client/internal/auth/sessionwatch/watcher_test.go
Normal file
@@ -0,0 +1,519 @@
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// fakeRecorder satisfies StatusRecorder and records every call so tests
|
||||
// can observe what the watcher emits. SetSessionExpiresAt and PublishEvent
|
||||
// land in the same ordered events slice (with the Kind distinguishing
|
||||
// them) so tests that care about ordering still work. lastDeadline holds
|
||||
// the most recent value passed to SetSessionExpiresAt so tests can assert
|
||||
// the recorder ended up cleared/set as expected.
|
||||
type fakeRecorder struct {
|
||||
mu sync.Mutex
|
||||
events []event
|
||||
lastDeadline time.Time
|
||||
}
|
||||
|
||||
type eventKind int
|
||||
|
||||
const (
|
||||
stateChange eventKind = iota
|
||||
publish
|
||||
)
|
||||
|
||||
type event struct {
|
||||
kind eventKind
|
||||
// Set only for publish events.
|
||||
severity cProto.SystemEvent_Severity
|
||||
category cProto.SystemEvent_Category
|
||||
message string
|
||||
meta map[string]string
|
||||
}
|
||||
|
||||
// SetSessionExpiresAt mirrors peer.Status: a same-value write is a no-op,
|
||||
// a real change records the new value and fans out a state-change (the
|
||||
// production recorder calls notifyStateChange internally). The baseline
|
||||
// is the zero time, so an initial clear before any deadline is set emits
|
||||
// nothing — matching the real recorder.
|
||||
func (r *fakeRecorder) SetSessionExpiresAt(deadline time.Time) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.lastDeadline.Equal(deadline) {
|
||||
return
|
||||
}
|
||||
r.lastDeadline = deadline
|
||||
r.events = append(r.events, event{kind: stateChange})
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) deadline() time.Time {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.lastDeadline
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) PublishEvent(
|
||||
severity cProto.SystemEvent_Severity,
|
||||
category cProto.SystemEvent_Category,
|
||||
message string,
|
||||
_ string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.events = append(r.events, event{
|
||||
kind: publish,
|
||||
severity: severity,
|
||||
category: category,
|
||||
message: message,
|
||||
meta: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) snapshot() []event {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([]event, len(r.events))
|
||||
copy(out, r.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e event) isFinalWarning() bool {
|
||||
return e.kind == publish && e.meta[MetaSessionFinal] == "true"
|
||||
}
|
||||
|
||||
func (e event) isWarning() bool {
|
||||
return e.kind == publish && e.meta[MetaSessionWarning] == "true" && e.meta[MetaSessionFinal] != "true"
|
||||
}
|
||||
|
||||
func countWhere(events []event, pred func(event) bool) int {
|
||||
n := 0
|
||||
for _, e := range events {
|
||||
if pred(e) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func waitForEvents(t *testing.T, r *fakeRecorder, want int) []event {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if got := r.snapshot(); len(got) >= want {
|
||||
return got
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
got := r.snapshot()
|
||||
t.Fatalf("timed out waiting for %d events, got %d: %+v", want, len(got), got)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newWatcher builds a watcher with the final timer disabled (finalLead=0),
|
||||
// matching the lead-only behaviour the pre-final-warning tests assume.
|
||||
func newWatcher(lead time.Duration, r *fakeRecorder) *Watcher {
|
||||
return NewWithLeads(lead, 0, r)
|
||||
}
|
||||
|
||||
func TestUpdateZeroBeforeAnythingIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
_ = w.Update(time.Time{})
|
||||
|
||||
if got := r.snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected no events on initial zero, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNonZeroFiresStateChange(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(time.Hour)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 1)
|
||||
if events[0].kind != stateChange {
|
||||
t.Fatalf("expected stateChange, got %+v", events[0])
|
||||
}
|
||||
if !w.Deadline().Equal(d) {
|
||||
t.Fatalf("deadline mismatch: %v vs %v", w.Deadline(), d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSameDeadlineIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(time.Hour)
|
||||
_ = w.Update(d)
|
||||
_ = w.Update(d)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 1)
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected exactly 1 event for repeated same deadline, got %d: %+v", len(events), events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarningFiresOnceWithinLeadWindow(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 50 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
// Deadline 80ms out — warning should fire after ~30ms.
|
||||
d := time.Now().Add(80 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[0].kind != stateChange {
|
||||
t.Fatalf("event[0] should be stateChange, got %+v", events[0])
|
||||
}
|
||||
if !events[1].isWarning() {
|
||||
t.Fatalf("event[1] should be a warning publish, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarningFiresImmediatelyWhenAlreadyInsideWindow(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r) // lead > delta => fire immediately
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(10 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if !events[1].isWarning() {
|
||||
t.Fatalf("expected immediate warning publish, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDeadlineCancelsPriorTimer(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 50 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(80 * time.Millisecond) // would fire warning ~30ms in
|
||||
_ = w.Update(first)
|
||||
|
||||
// Replace with a far-future deadline before the warning fires.
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
second := time.Now().Add(time.Hour)
|
||||
_ = w.Update(second)
|
||||
|
||||
// Wait past when first's warning would have fired.
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
|
||||
if n := countWhere(r.snapshot(), event.isWarning); n != 0 {
|
||||
t.Fatalf("warning fired for cancelled deadline: %+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAfterFireArmsNewWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 30 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(50 * time.Millisecond)
|
||||
_ = w.Update(first)
|
||||
|
||||
// Wait for stateChange + warning of the first cycle.
|
||||
waitForEvents(t, r, 2)
|
||||
|
||||
// Simulate a successful extend: brand new deadline.
|
||||
second := time.Now().Add(60 * time.Millisecond)
|
||||
_ = w.Update(second)
|
||||
|
||||
// 4 events total: stateChange, warning (first), stateChange, warning (second).
|
||||
events := waitForEvents(t, r, 4)
|
||||
if events[2].kind != stateChange {
|
||||
t.Fatalf("event[2] should be stateChange for the new deadline, got %+v", events[2])
|
||||
}
|
||||
if !events[3].isWarning() {
|
||||
t.Fatalf("event[3] should be a warning publish for the new deadline, got %+v", events[3])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateZeroAfterNonZeroClearsState(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(2 * time.Hour)
|
||||
_ = w.Update(d)
|
||||
waitForEvents(t, r, 1)
|
||||
|
||||
_ = w.Update(time.Time{})
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[1].kind != stateChange {
|
||||
t.Fatalf("expected stateChange on clear, got %+v", events[1])
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("Deadline should be zero after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRejectsBeforeEpoch(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
|
||||
err := w.Update(time.Unix(-100, 0))
|
||||
if !errors.Is(err, ErrDeadlineBeforeEpoch) {
|
||||
t.Fatalf("want ErrDeadlineBeforeEpoch, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("rejected pre-epoch update must clear deadline; got %v", w.Deadline())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRejectsTooFarFuture(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
|
||||
err := w.Update(time.Now().Add(50 * 365 * 24 * time.Hour))
|
||||
if !errors.Is(err, ErrDeadlineTooFarFuture) {
|
||||
t.Fatalf("want ErrDeadlineTooFarFuture, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("rejected far-future update must clear deadline; got %v", w.Deadline())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInPastClearsDeadline(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
// Drain the stateChange from the seed.
|
||||
waitForEvents(t, r, 1)
|
||||
|
||||
err := w.Update(time.Now().Add(-1 * time.Hour))
|
||||
if !errors.Is(err, ErrDeadlineInPast) {
|
||||
t.Fatalf("want ErrDeadlineInPast, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("in-past update must clear the deadline, got %v", w.Deadline())
|
||||
}
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[1].kind != stateChange {
|
||||
t.Fatalf("expected stateChange on clear, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWithinSkewAccepted(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
// 5 seconds in the past is within the 30s Skew tolerance — accept it.
|
||||
d := time.Now().Add(-5 * time.Second)
|
||||
if err := w.Update(d); err != nil {
|
||||
t.Fatalf("within-skew Update should succeed, got %v", err)
|
||||
}
|
||||
if !w.Deadline().Equal(d) {
|
||||
t.Fatalf("expected deadline to be applied, got %v want %v", w.Deadline(), d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseSilencesUpdates(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
w.Close()
|
||||
|
||||
_ = w.Update(time.Now().Add(time.Hour))
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if got := r.snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected no events after Close, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloseClearsRecorderDeadline pins the profile-switch fix: a watcher
|
||||
// holding a live deadline must zero the recorder on Close so the next
|
||||
// engine's watcher (and the UI reading the shared server-scoped recorder)
|
||||
// doesn't start out showing the previous session's stale "expires in".
|
||||
func TestCloseClearsRecorderDeadline(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r)
|
||||
|
||||
d := time.Now().Add(2 * time.Hour)
|
||||
if err := w.Update(d); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
if got := r.deadline(); !got.Equal(d) {
|
||||
t.Fatalf("recorder deadline after Update = %v, want %v", got, d)
|
||||
}
|
||||
|
||||
w.Close()
|
||||
|
||||
if got := r.deadline(); !got.IsZero() {
|
||||
t.Fatalf("recorder deadline after Close = %v, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloseWithoutDeadlineLeavesRecorderUntouched guards the symmetric
|
||||
// case: closing a watcher that never held a deadline must not emit a
|
||||
// redundant clear (the recorder may legitimately hold a value written by
|
||||
// some other path; the watcher only owns what it set).
|
||||
func TestCloseWithoutDeadlineLeavesRecorderUntouched(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r)
|
||||
|
||||
w.Close()
|
||||
|
||||
if got := r.snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected no events from Close on an empty watcher, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalWarningFiresAfterRegularWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
// Warning fires at deadline-80ms, final at deadline-30ms.
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Expect stateChange + warning + final-warning.
|
||||
events := waitForEvents(t, r, 3)
|
||||
|
||||
if countWhere(events, func(e event) bool { return e.kind == stateChange }) != 1 {
|
||||
t.Fatalf("expected exactly 1 stateChange, got %+v", events)
|
||||
}
|
||||
if countWhere(events, event.isWarning) != 1 {
|
||||
t.Fatalf("expected exactly 1 warning publish, got %+v", events)
|
||||
}
|
||||
if countWhere(events, event.isFinalWarning) != 1 {
|
||||
t.Fatalf("expected exactly 1 final-warning publish, got %+v", events)
|
||||
}
|
||||
|
||||
// Warning must precede final (same deadline, longer lead fires first).
|
||||
var wIdx, fIdx int
|
||||
for i, e := range events {
|
||||
switch {
|
||||
case e.isWarning():
|
||||
wIdx = i
|
||||
case e.isFinalWarning():
|
||||
fIdx = i
|
||||
}
|
||||
}
|
||||
if wIdx > fIdx {
|
||||
t.Fatalf("warning must publish before final-warning, got order %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissSuppressesFinalWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Wait for the warning publish so we know we're inside the warning
|
||||
// window, then dismiss before the final timer would fire.
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isWarning) >= 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
if countWhere(r.snapshot(), event.isWarning) < 1 {
|
||||
t.Fatalf("warning did not publish in time, events=%+v", r.snapshot())
|
||||
}
|
||||
|
||||
w.Dismiss()
|
||||
|
||||
// Now wait past when the final would have fired.
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
if n := countWhere(r.snapshot(), event.isFinalWarning); n != 0 {
|
||||
t.Fatalf("final-warning published after Dismiss(), events=%+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissResetByNewDeadline(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(first)
|
||||
|
||||
// Dismiss against the first deadline.
|
||||
w.Dismiss()
|
||||
|
||||
// Replace with a fresh deadline before the first's timers complete.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
second := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(second)
|
||||
|
||||
// The second cycle must publish a final-warning (the dismiss state
|
||||
// did not carry over).
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) >= 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) < 1 {
|
||||
t.Fatalf("final-warning did not publish on fresh deadline after Dismiss reset, events=%+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissBeforeUpdateIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
// No deadline tracked yet; Dismiss must be a no-op (no panic, no state).
|
||||
w.Dismiss()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Final warning should still publish — Dismiss only acts on the current
|
||||
// deadline, and there was none at the time of the call.
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) >= 1 {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("final-warning did not publish after no-op pre-Update Dismiss, events=%+v", r.snapshot())
|
||||
}
|
||||
@@ -257,6 +257,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
// On daemon shutdown / Down() the parent context is cancelled
|
||||
// and the dial fails with "context canceled". Wrapping that
|
||||
// into state would leave the snapshot stuck at Connecting+err
|
||||
// until the backoff loop wakes up — instead let the operation
|
||||
// return cleanly so the deferred state.Set(StatusIdle) takes
|
||||
// effect on the next iteration.
|
||||
if c.ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||
@@ -390,6 +399,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
// Seed the session-expiry deadline from the LoginResponse. Subsequent
|
||||
// changes flow in through SyncResponse and are applied in handleSync.
|
||||
engine.ApplySessionDeadline(loginResp.GetSessionExpiresAt())
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
@@ -430,7 +443,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
c.statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
// Wrap the backoff with c.ctx so Down()/actCancel propagates into the
|
||||
// inter-attempt sleep — otherwise a 15s MaxInterval can keep the retry
|
||||
// loop alive long after the caller asked to give up, leaving the
|
||||
// status stream stuck at Connecting.
|
||||
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
|
||||
@@ -516,14 +516,6 @@ func (g *BundleGenerator) addConfig() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Surface the set of MDM-enforced keys so a support engineer reading
|
||||
// the bundle can tell which field values are user-set vs MDM-overridden.
|
||||
// Same semantics as the mDMManagedFields list returned by the
|
||||
// GetConfig RPC consumed by `netbird debug config`.
|
||||
if managed := g.internalConfig.Policy().ManagedKeys(); len(managed) > 0 {
|
||||
configContent.WriteString(fmt.Sprintf("MDMManagedFields: %v\n", managed))
|
||||
}
|
||||
|
||||
configReader := strings.NewReader(configContent.String())
|
||||
if err := g.addFileToZip(configReader, "config.txt"); err != nil {
|
||||
return fmt.Errorf("add config file to zip: %w", err)
|
||||
|
||||
@@ -843,7 +843,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
|
||||
}
|
||||
|
||||
mURL, _ := url.Parse("https://api.example.com:443")
|
||||
|
||||
@@ -482,7 +482,7 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) < 2 {
|
||||
if len(records) == 0 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
|
||||
@@ -2738,17 +2738,6 @@ func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
{
|
||||
// A single answer is never filtered: dropping it would only
|
||||
// trigger the empty-answer escape hatch, so the fast path
|
||||
// returns it untouched.
|
||||
name: "single disconnected answer passes through",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
@@ -14,10 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// errNoSuitableAddress mirrors the unexported error string the net package
|
||||
// uses when a resolved host has no addresses of the requested family.
|
||||
const errNoSuitableAddress = "no suitable address found"
|
||||
|
||||
// GenerateRequestID creates a random 8-character hex string for request tracing.
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
@@ -130,14 +126,6 @@ func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint1
|
||||
}
|
||||
|
||||
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
|
||||
// The net package returns this AddrError when the host resolves but has
|
||||
// no addresses of the requested family. The domain exists, so answer
|
||||
// NODATA instead of SERVFAIL.
|
||||
var addrErr *net.AddrError
|
||||
if errors.As(err, &addrErr) && addrErr.Err == errNoSuitableAddress {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
return dns.RcodeServerFailure
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
package resutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockResolver struct {
|
||||
// results maps network ("ip4"/"ip6") to the lookup outcome.
|
||||
results map[string]mockLookup
|
||||
}
|
||||
|
||||
type mockLookup struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockResolver) LookupNetIP(_ context.Context, network, _ string) ([]netip.Addr, error) {
|
||||
res, ok := m.results[network]
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected network: " + network)
|
||||
}
|
||||
return res.ips, res.err
|
||||
}
|
||||
|
||||
func TestLookupIP_Success(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {ips: []netip.Addr{netip.MustParseAddr("::ffff:192.0.2.1")}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "successful lookup should return NOERROR")
|
||||
require.Len(t, result.IPs, 1, "should return the resolved address")
|
||||
assert.Equal(t, netip.MustParseAddr("192.0.2.1"), result.IPs[0], "v4-mapped address should be unmapped")
|
||||
}
|
||||
|
||||
func TestLookupIP_NoSuitableAddress(t *testing.T) {
|
||||
// The net package returns this AddrError when the host resolves but has
|
||||
// no addresses of the requested family (e.g. AAAA query for a v4-only
|
||||
// hosts file entry). The domain exists, so this is NODATA, not SERVFAIL.
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip6": {err: &net.AddrError{Err: "no suitable address found", Addr: "example.com."}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip6", "example.com.", dns.TypeAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "no suitable address should map to NODATA")
|
||||
assert.Empty(t, result.IPs, "NODATA response should carry no addresses")
|
||||
}
|
||||
|
||||
// TestErrNoSuitableAddressMatchesNetPackage pins our copy of the error string
|
||||
// to what the net package actually emits. A literal IP of the wrong family
|
||||
// takes the same filterAddrList path as a resolved hostname, without network
|
||||
// access.
|
||||
func TestErrNoSuitableAddressMatchesNetPackage(t *testing.T) {
|
||||
_, err := (&net.Resolver{}).LookupNetIP(context.Background(), "ip6", "192.0.2.1")
|
||||
require.Error(t, err)
|
||||
|
||||
var addrErr *net.AddrError
|
||||
require.ErrorAs(t, err, &addrErr, "wrong-family lookup should return AddrError")
|
||||
assert.Equal(t, errNoSuitableAddress, addrErr.Err, "net package error string should match our constant")
|
||||
}
|
||||
|
||||
func TestLookupIP_OtherAddrError(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.AddrError{Err: "some other address problem", Addr: "example.com."}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "unrecognized AddrError should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestLookupIP_NotFoundNXDomain(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
"ip6": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeNameError, result.Rcode, "not found for both families should map to NXDOMAIN")
|
||||
}
|
||||
|
||||
func TestLookupIP_NotFoundNoData(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip6": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
"ip4": {ips: []netip.Addr{netip.MustParseAddr("192.0.2.1")}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip6", "example.com.", dns.TypeAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "not found with the other family present should map to NODATA")
|
||||
}
|
||||
|
||||
func TestLookupIP_GenericError(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: errors.New("connection refused")},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "generic error should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.DNSError{Err: "server misbehaving", Name: "example.com.", IsTemporary: true}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
@@ -1,501 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wgIface.Close()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = dnsServer.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
}
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -102,6 +104,481 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wgIface.Close()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = dnsServer.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
}
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -260,6 +260,20 @@ type Engine struct {
|
||||
jobExecutorWG sync.WaitGroup
|
||||
|
||||
exposeManager *expose.Manager
|
||||
|
||||
sessionWatcher sessionDeadlineWatcher
|
||||
}
|
||||
|
||||
// sessionDeadlineWatcher is the engine-facing surface of the SSO session
|
||||
// expiry watcher. The concrete implementation (sessionwatch.Watcher) is wired
|
||||
// in via newSessionWatcher, which is build-tagged so the js/wasm build links a
|
||||
// no-op stub instead of pulling the full sessionwatch package (and its timer
|
||||
// machinery) into the binary — the wasm client never runs the engine's
|
||||
// session-warning flow.
|
||||
type sessionDeadlineWatcher interface {
|
||||
Update(deadline time.Time) error
|
||||
Dismiss()
|
||||
Close()
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -304,6 +318,17 @@ func NewEngine(
|
||||
updateManager: services.UpdateManager,
|
||||
syncStoreDir: config.StateDir,
|
||||
}
|
||||
// sessionWatcher keeps the SubscribeStatus consumers in sync with the
|
||||
// session expiry deadline. Deadline-change ticks come for free via
|
||||
// Status.SetSessionExpiresAt; the watcher exists to push a wake-up at
|
||||
// T-WarningLead and T-FinalWarningLead so the UI repaints the remaining
|
||||
// time / warning state even when nothing else changed, and to publish
|
||||
// two SystemEvents (the warning composition lives in sessionwatch so
|
||||
// the wire format stays owned by one package):
|
||||
// - T-WarningLead → interactive "Extend now / Dismiss" notification
|
||||
// - T-FinalWarningLead → auto-opened SessionAboutToExpire dialog,
|
||||
// suppressed when the user dismissed the earlier warning
|
||||
engine.sessionWatcher = newSessionWatcher(engine.statusRecorder)
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
@@ -344,6 +369,10 @@ func (e *Engine) Stop() error {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
|
||||
if e.sessionWatcher != nil {
|
||||
e.sessionWatcher.Close()
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.SetDownloadOnly()
|
||||
}
|
||||
@@ -876,98 +905,81 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
e.ApplySessionDeadline(update.GetSessionExpiresAt())
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
|
||||
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
err = e.updateSTUNs(wCfg.GetStuns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
err = e.handleRelayUpdate(wCfg.GetRelay())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.handleFlowUpdate(wCfg.GetFlow())
|
||||
if err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.persistSyncResponse(update)
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Persist sync response only after updateNetworkMap accepted and applied the update,
|
||||
// so GetLatestSyncResponse() never returns state the engine did not actually apply.
|
||||
// Done under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
|
||||
// the whole Set so the store cannot be cleared (disabled / engine close)
|
||||
// mid-call and have this write resurrect a file that was just removed.
|
||||
e.syncRespMux.RLock()
|
||||
if e.syncStore != nil {
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
} else {
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
}
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
|
||||
// which is the case for sync updates carrying only a network map.
|
||||
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
||||
if wCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistSyncResponse stores the full sync response so it can be restored on the next
|
||||
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
|
||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||
e.syncRespMux.RLock()
|
||||
defer e.syncRespMux.RUnlock()
|
||||
|
||||
if e.syncStore == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
|
||||
}
|
||||
|
||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
if update != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
@@ -1175,7 +1187,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
ClientMetrics: e.clientMetrics,
|
||||
DaemonVersion: version.NetbirdVersion(),
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
e.RunHealthProbes(e.ctx, true)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2072,7 +2084,20 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
//
|
||||
// ctx scopes the (potentially slow) STUN/TURN probing: a caller that gives up —
|
||||
// e.g. a Status RPC whose client disconnected — cancels its ctx and the probe
|
||||
// returns instead of running to its per-component timeout. The engine's own
|
||||
// lifetime ctx still applies independently, so an engine shutdown aborts the
|
||||
// probe even if the caller's ctx is context.Background().
|
||||
func (e *Engine) RunHealthProbes(ctx context.Context, waitForResult bool) bool {
|
||||
// Tie the caller's ctx to the engine lifetime: either cancelling aborts
|
||||
// the probe below.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
stop := context.AfterFunc(e.ctx, cancel)
|
||||
defer stop()
|
||||
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
signalHealthy := e.signal.IsHealthy()
|
||||
@@ -2095,9 +2120,9 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
if runtime.GOOS != "js" {
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
results = e.probeStunTurn.ProbeAll(ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
|
||||
108
client/internal/engine_authsession.go
Normal file
108
client/internal/engine_authsession.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
// ApplySessionDeadline propagates the absolute SSO session deadline carried on
|
||||
// LoginResponse / SyncResponse to both the watcher (for the edge-triggered
|
||||
// warning) and the status recorder (for the SubscribeStatus / Status RPC
|
||||
// snapshot the UI consumes).
|
||||
//
|
||||
// The wire field is 3-state:
|
||||
// - nil → snapshot carries no info; keep the
|
||||
// previously-anchored deadline (no-op)
|
||||
// - explicit zero (s=0, n=0) → peer is not SSO-registered or expiry is
|
||||
// disabled; clear both sinks
|
||||
// - valid timestamp → new deadline; arm watcher, expose on
|
||||
// status recorder
|
||||
//
|
||||
// Deadline sanity-checks live in sessionwatch.Watcher.Update. Any rejected
|
||||
// value is treated as a clear on both sinks: the alternative — leaving the
|
||||
// previously-known deadline in place — risks the UI confidently displaying
|
||||
// a stale "expires in X" while the server has actually invalidated it.
|
||||
func (e *Engine) ApplySessionDeadline(ts *timestamppb.Timestamp) {
|
||||
if ts == nil {
|
||||
return
|
||||
}
|
||||
var deadline time.Time
|
||||
// Explicit zero (seconds=0 AND nanos=0) is the sentinel for "disabled".
|
||||
// Everything else flows through Watcher.Update, whose sanity-checks
|
||||
// reject out-of-range / pre-epoch / far-future / too-stale values and
|
||||
// clear on rejection.
|
||||
if ts.GetSeconds() != 0 || ts.GetNanos() != 0 {
|
||||
deadline = ts.AsTime().UTC()
|
||||
}
|
||||
if e.sessionWatcher == nil {
|
||||
return
|
||||
}
|
||||
// Watcher.Update owns the propagation to the status recorder (the
|
||||
// SubscribeStatus / Status snapshot the UI reads): a set writes the
|
||||
// deadline, a clear or a sanity-check rejection writes the zero value.
|
||||
// Keeping a single writer is what stops the recorder from drifting out
|
||||
// of sync with the warning timers.
|
||||
if err := e.sessionWatcher.Update(deadline); err != nil {
|
||||
log.Errorf("auth session deadline rejected: %v, clearing", err)
|
||||
e.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_ERROR,
|
||||
cProto.SystemEvent_AUTHENTICATION,
|
||||
"session deadline rejected",
|
||||
"",
|
||||
map[string]string{sessionwatch.MetaSessionDeadlineRejected: err.Error()},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// DismissSessionWarning records the user's "Dismiss" click on the
|
||||
// T-WarningLead interactive notification and suppresses the upcoming
|
||||
// T-FinalWarningLead fallback for the current deadline. No-op when the
|
||||
// watcher is not running or holds no deadline.
|
||||
func (e *Engine) DismissSessionWarning() {
|
||||
if e.sessionWatcher == nil {
|
||||
return
|
||||
}
|
||||
e.sessionWatcher.Dismiss()
|
||||
}
|
||||
|
||||
// ExtendAuthSession asks the management server to refresh the SSO session
|
||||
// expiry deadline using the supplied JWT, then mirrors the new deadline into
|
||||
// the daemon's state. The tunnel is untouched; no resync, no reconnect.
|
||||
//
|
||||
// Returns the new absolute UTC deadline (or zero time when the server
|
||||
// reports the peer is not eligible for extension).
|
||||
func (e *Engine) ExtendAuthSession(ctx context.Context, jwtToken string) (time.Time, error) {
|
||||
if jwtToken == "" {
|
||||
return time.Time{}, errors.New("jwt token is required")
|
||||
}
|
||||
if e.mgmClient == nil {
|
||||
return time.Time{}, errors.New("management client is not initialised")
|
||||
}
|
||||
|
||||
info, err := system.GetInfoWithChecks(ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to collect system info for session extend: %v", err)
|
||||
info = system.GetInfo(ctx)
|
||||
}
|
||||
|
||||
resp, err := e.mgmClient.ExtendAuthSession(info, jwtToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("extend auth session on management: %w", err)
|
||||
}
|
||||
|
||||
e.ApplySessionDeadline(resp.GetSessionExpiresAt())
|
||||
|
||||
if resp.GetSessionExpiresAt().IsValid() {
|
||||
return resp.GetSessionExpiresAt().AsTime().UTC(), nil
|
||||
}
|
||||
return time.Time{}, nil
|
||||
}
|
||||
@@ -1,565 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.21/24"},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||
},
|
||||
}
|
||||
|
||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||
networkMap := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// SSH server is enabled, therefore SSH config should be applied
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 8,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 9,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
mu := sync.Mutex{}
|
||||
engines := []*Engine{}
|
||||
numPeers := 10
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numPeers)
|
||||
// create and start peers
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
engines = append(engines, engine)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until all have been created and started
|
||||
wg.Wait()
|
||||
if len(engines) != numPeers {
|
||||
t.Fatal("not all peers was started")
|
||||
}
|
||||
// check whether all the peer have expected peers connected
|
||||
|
||||
expectedConnected := numPeers * (numPeers - 1)
|
||||
|
||||
// adjust according to timeouts
|
||||
timeout := 50 * time.Second
|
||||
timeoutChan := time.After(timeout)
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected += getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
break loop
|
||||
}
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
}
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
}
|
||||
errStop = peerEngine.Stop()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifaceName string
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
78
client/internal/engine_session_deadline_test.go
Normal file
78
client/internal/engine_session_deadline_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
// TestApplySessionDeadline_ThreeState pins down the 3-state semantics of the
|
||||
// wire field carried on LoginResponse / SyncResponse:
|
||||
//
|
||||
// - nil pointer → no info; previously-anchored deadline survives
|
||||
// - explicit zero value → "expiry disabled" sentinel; both sinks cleared
|
||||
// - valid future timestamp → new deadline propagated to both sinks
|
||||
func TestApplySessionDeadline_ThreeState(t *testing.T) {
|
||||
newEngine := func() *Engine {
|
||||
recorder := peer.NewRecorder("")
|
||||
return &Engine{
|
||||
statusRecorder: recorder,
|
||||
sessionWatcher: sessionwatch.New(recorder),
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("valid timestamp sets deadline on both sinks", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
deadline := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
|
||||
e.ApplySessionDeadline(timestamppb.New(deadline))
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(deadline),
|
||||
"status recorder should hold the new deadline")
|
||||
})
|
||||
|
||||
t.Run("nil is a no-op and preserves previous deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
e.ApplySessionDeadline(nil)
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded),
|
||||
"nil snapshot must not disturb the existing deadline")
|
||||
})
|
||||
|
||||
t.Run("explicit zero clears a previously-anchored deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
// Explicit zero Timestamp{} (seconds=0, nanos=0) is the
|
||||
// "expiry disabled / not SSO" sentinel.
|
||||
e.ApplySessionDeadline(×tamppb.Timestamp{})
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().IsZero(),
|
||||
"explicit zero sentinel must clear the deadline")
|
||||
})
|
||||
|
||||
t.Run("invalid timestamp clears the deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
// Out-of-range nanos → IsValid()==false; same-meaning as the
|
||||
// disabled sentinel for downstream sinks.
|
||||
e.ApplySessionDeadline(×tamppb.Timestamp{Seconds: 1, Nanos: -1})
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().IsZero(),
|
||||
"invalid timestamp must clear the deadline")
|
||||
})
|
||||
}
|
||||
16
client/internal/engine_sessionwatch.go
Normal file
16
client/internal/engine_sessionwatch.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
// newSessionWatcher returns the real SSO session expiry watcher for every
|
||||
// non-wasm build. The js/wasm build gets a no-op stub from
|
||||
// engine_sessionwatch_js.go so the sessionwatch package (and its timer
|
||||
// machinery) never links into the wasm binary.
|
||||
func newSessionWatcher(recorder *peer.Status) sessionDeadlineWatcher {
|
||||
return sessionwatch.New(recorder)
|
||||
}
|
||||
44
client/internal/engine_sessionwatch_js.go
Normal file
44
client/internal/engine_sessionwatch_js.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
// noopSessionWatcher is the js/wasm stand-in for sessionwatch.Watcher. The
|
||||
// wasm client never runs the engine's session-warning flow (the interactive
|
||||
// T-WarningLead notification and the T-FinalWarningLead fallback dialog live
|
||||
// in the desktop UI), so linking the full sessionwatch package (timers, event
|
||||
// composition) would only bloat the binary.
|
||||
//
|
||||
// It still mirrors the deadline into the status recorder so the SubscribeStatus
|
||||
// / Status snapshot the UI consumes stays correct — only the timer-driven
|
||||
// warnings are dropped.
|
||||
type noopSessionWatcher struct {
|
||||
recorder *peer.Status
|
||||
}
|
||||
|
||||
func newSessionWatcher(recorder *peer.Status) sessionDeadlineWatcher {
|
||||
return noopSessionWatcher{recorder: recorder}
|
||||
}
|
||||
|
||||
// Update mirrors the real watcher's recorder propagation without the timers or
|
||||
// sanity-check sentinels: a valid deadline is exposed on the status snapshot,
|
||||
// the zero time clears it.
|
||||
func (w noopSessionWatcher) Update(deadline time.Time) error {
|
||||
if w.recorder != nil {
|
||||
w.recorder.SetSessionExpiresAt(deadline)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (noopSessionWatcher) Dismiss() {
|
||||
// No-op: only suppresses the timer-driven final-warning, which this stub never arms.
|
||||
}
|
||||
|
||||
func (noopSessionWatcher) Close() {
|
||||
// No-op: no timers to stop and no state to unwind; the recorder is cleared via Update(zero).
|
||||
}
|
||||
@@ -6,18 +6,37 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -31,7 +50,18 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
@@ -39,9 +69,25 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
@@ -188,6 +234,129 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.21/24"},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||
},
|
||||
}
|
||||
|
||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||
networkMap := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// SSH server is enabled, therefore SSH config should be applied
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 8,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 9,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||
// Test that SSH server start/stop logic works based on config
|
||||
engine := &Engine{
|
||||
@@ -462,6 +631,97 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -845,6 +1105,104 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
mu := sync.Mutex{}
|
||||
engines := []*Engine{}
|
||||
numPeers := 10
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numPeers)
|
||||
// create and start peers
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
engines = append(engines, engine)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until all have been created and started
|
||||
wg.Wait()
|
||||
if len(engines) != numPeers {
|
||||
t.Fatal("not all peers was started")
|
||||
}
|
||||
// check whether all the peer have expected peers connected
|
||||
|
||||
expectedConnected := numPeers * (numPeers - 1)
|
||||
|
||||
// adjust according to timeouts
|
||||
timeout := 50 * time.Second
|
||||
timeoutChan := time.After(timeout)
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected += getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
break loop
|
||||
}
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
}
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
}
|
||||
errStop = peerEngine.Stop()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
ifaceList, err := net.Interfaces()
|
||||
if err != nil {
|
||||
@@ -1168,6 +1526,187 @@ func TestCompareNetIPLists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifaceName string
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
|
||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||
t.Helper()
|
||||
b, err := netiputil.EncodePrefix(p)
|
||||
|
||||
@@ -26,6 +26,7 @@ type connStatusInputs struct {
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -191,22 +193,27 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
// every private-service request) don't contend against each other.
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
// sessionExpiresAt is the absolute UTC instant at which the peer's SSO
|
||||
// session expires. Zero when the peer is not SSO-tracked or login
|
||||
// expiration is disabled. Populated from management LoginResponse /
|
||||
// SyncResponse and exposed via the daemon's Status / SubscribeStatus RPC
|
||||
// so the UI can show remaining time without itself talking to mgm.
|
||||
sessionExpiresAt time.Time
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||
lazyConnectionEnabled bool
|
||||
@@ -222,6 +229,21 @@ type Status struct {
|
||||
eventStreams map[string]chan *proto.SystemEvent
|
||||
eventQueue *EventQueue
|
||||
|
||||
// stateChangeStreams fan-out connection-state changes (connected /
|
||||
// disconnected / connecting / address change / peers list change) to
|
||||
// every active SubscribeStatus gRPC stream. Each subscriber gets a
|
||||
// buffered chan; the notifier non-blockingly pings them so a slow
|
||||
// consumer can never stall the daemon.
|
||||
stateChangeMux sync.Mutex
|
||||
stateChangeStreams map[string]chan struct{}
|
||||
|
||||
// networksRevision bumps whenever the routed-networks set or their
|
||||
// selected state changes (driven by the route manager). Surfaced in the
|
||||
// status snapshot so the UI can fingerprint on it and re-fetch
|
||||
// ListNetworks only on a real change. Atomic so the snapshot builder can
|
||||
// read it without taking mux.
|
||||
networksRevision atomic.Uint64
|
||||
|
||||
ingressGwMgr *ingressgw.Manager
|
||||
|
||||
routeIDLookup routeIDLookup
|
||||
@@ -232,10 +254,10 @@ type Status struct {
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
ipToKey: make(map[string]string),
|
||||
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
stateChangeStreams: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
@@ -284,12 +306,6 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
if ipv6 != "" {
|
||||
d.ipToKey[ipv6] = peerPubKey
|
||||
}
|
||||
if ip != "" {
|
||||
d.ipToKey[ip] = peerPubKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -319,22 +335,28 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Only
|
||||
// active peers are matched; peers moved into the offline slice by
|
||||
// ReplaceOfflinePeers are intentionally treated as unknown.
|
||||
// address so dual-stack peers are reachable on either family. Searches
|
||||
// both d.peers and d.offlinePeers — peers that have been moved into
|
||||
// the offline slice by ReplaceOfflinePeers are still part of the
|
||||
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
|
||||
// need to recognise them rather than treating them as unknown. Returns
|
||||
// the zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
key, ok := d.ipToKey[ip]
|
||||
if !ok {
|
||||
return State{}, false
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
state, ok := d.peers[key]
|
||||
if ok {
|
||||
return state, true
|
||||
for _, state := range d.offlinePeers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
@@ -344,18 +366,12 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
p, ok := d.peers[peerPubKey]
|
||||
_, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("no peer with to remove")
|
||||
}
|
||||
|
||||
delete(d.peers, peerPubKey)
|
||||
if mappedKey, exists := d.ipToKey[p.IP]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IP)
|
||||
}
|
||||
if mappedKey, exists := d.ipToKey[p.IPv6]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IPv6)
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
return nil
|
||||
}
|
||||
@@ -400,6 +416,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -425,6 +442,7 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -450,6 +468,7 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -499,6 +518,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -535,6 +555,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -570,6 +591,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -608,6 +630,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -701,6 +724,7 @@ func (d *Status) FinishPeerListModifications() {
|
||||
for _, rd := range dispatches {
|
||||
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
|
||||
@@ -759,6 +783,41 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// SetSessionExpiresAt records the absolute UTC instant at which the peer's
|
||||
// SSO session is set to expire. Pass the zero value to clear (e.g. when the
|
||||
// management server stops publishing a deadline because login expiration was
|
||||
// disabled or the peer is not SSO-tracked). Same-value updates are no-ops;
|
||||
// real changes fan out via notifyStateChange so SubscribeStatus consumers
|
||||
// pick up the new deadline on their next read.
|
||||
func (d *Status) SetSessionExpiresAt(deadline time.Time) {
|
||||
d.mux.Lock()
|
||||
if d.sessionExpiresAt.Equal(deadline) {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.sessionExpiresAt = deadline
|
||||
d.mux.Unlock()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// GetSessionExpiresAt returns the most recently recorded SSO session deadline,
|
||||
// or the zero value when no deadline is tracked. A deadline that has already
|
||||
// slipped into the past reports as "none": once the session has expired it is
|
||||
// no longer a meaningful countdown, and the sessionwatch.Watcher does not
|
||||
// arm a timer at the deadline itself to clear it (only the two pre-expiry
|
||||
// warnings). Without this guard the UI would keep painting a stale
|
||||
// "expires in …" against a moment that has passed until the next login,
|
||||
// extend, or teardown rewrote the value.
|
||||
func (d *Status) GetSessionExpiresAt() time.Time {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
if !d.sessionExpiresAt.IsZero() && d.sessionExpiresAt.Before(time.Now()) {
|
||||
return time.Time{}
|
||||
}
|
||||
return d.sessionExpiresAt
|
||||
}
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
@@ -827,11 +886,19 @@ func (d *Status) CleanLocalPeerState() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
// Health checks re-mark the same state on every probe; skip the fan-out
|
||||
// when nothing actually changed so we don't flood SubscribeStatus
|
||||
// consumers with identical snapshots.
|
||||
if !d.managementState && errors.Is(d.managementError, err) {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.managementState = false
|
||||
d.managementError = err
|
||||
mgm := d.managementState
|
||||
@@ -839,11 +906,16 @@ func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkManagementConnected sets ManagementState to connected
|
||||
func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Lock()
|
||||
if d.managementState && d.managementError == nil {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.managementState = true
|
||||
d.managementError = nil
|
||||
mgm := d.managementState
|
||||
@@ -851,6 +923,7 @@ func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
@@ -884,6 +957,10 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
// MarkSignalDisconnected sets SignalState to disconnected
|
||||
func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
if !d.signalState && errors.Is(d.signalError, err) {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.signalState = false
|
||||
d.signalError = err
|
||||
mgm := d.managementState
|
||||
@@ -891,11 +968,16 @@ func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Lock()
|
||||
if d.signalState && d.signalError == nil {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.signalState = true
|
||||
d.signalError = nil
|
||||
mgm := d.managementState
|
||||
@@ -903,6 +985,7 @@ func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
@@ -1100,16 +1183,19 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
// ClientStart will notify all listeners about the new service state
|
||||
func (d *Status) ClientStart() {
|
||||
d.notifier.clientStart()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// ClientStop will notify all listeners about the new service state
|
||||
func (d *Status) ClientStop() {
|
||||
d.notifier.clientStop()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// ClientTeardown will notify all listeners about the service is under teardown
|
||||
func (d *Status) ClientTeardown() {
|
||||
d.notifier.clientTearDown()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// SetConnectionListener set a listener to the notifier
|
||||
@@ -1251,6 +1337,82 @@ func (d *Status) GetEventHistory() []*proto.SystemEvent {
|
||||
return d.eventQueue.GetAll()
|
||||
}
|
||||
|
||||
// SubscribeToStateChanges hands back a channel that receives a tick on
|
||||
// every connection-state change (connected / disconnected / connecting /
|
||||
// address change / peers-list change). The channel is buffered to one
|
||||
// pending tick so a coalesced burst still wakes the consumer exactly
|
||||
// once. Pass the returned id to UnsubscribeFromStateChanges to detach.
|
||||
func (d *Status) SubscribeToStateChanges() (string, <-chan struct{}) {
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
id := uuid.New().String()
|
||||
ch := make(chan struct{}, 1)
|
||||
d.stateChangeStreams[id] = ch
|
||||
return id, ch
|
||||
}
|
||||
|
||||
// UnsubscribeFromStateChanges releases a SubscribeToStateChanges channel
|
||||
// and closes it so any consumer goroutine selecting on the channel
|
||||
// unblocks cleanly.
|
||||
func (d *Status) UnsubscribeFromStateChanges(id string) {
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
if ch, ok := d.stateChangeStreams[id]; ok {
|
||||
close(ch)
|
||||
delete(d.stateChangeStreams, id)
|
||||
}
|
||||
}
|
||||
|
||||
// notifyStateChange wakes every SubscribeToStateChanges subscriber. Drops
|
||||
// the tick if a subscriber's buffer is full — by definition the consumer
|
||||
// is already going to fetch the latest snapshot, so multiple pending ticks
|
||||
// would be redundant.
|
||||
func (d *Status) notifyStateChange() {
|
||||
if _, file, line, ok := runtime.Caller(1); ok {
|
||||
log.Infof("--- notifyStateChange from %s:%d", file, line)
|
||||
}
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
for _, ch := range d.stateChangeStreams {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyStateChange is the public wake-the-subscribers entry point used by
|
||||
// callers that mutate state outside the peer recorder — most importantly
|
||||
// the connect-state machine, which writes StatusNeedsLogin into the
|
||||
// shared contextState (client/internal/state.go) without touching any
|
||||
// recorder field. Without this push the SubscribeStatus stream stays on
|
||||
// the previous snapshot until an unrelated peer/management/signal
|
||||
// change happens to fire notifyStateChange, leaving the UI's status
|
||||
// out of sync with the daemon.
|
||||
func (d *Status) NotifyStateChange() {
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// BumpNetworksRevision increments the routed-networks revision and wakes every
|
||||
// SubscribeStatus subscriber. The route manager calls it when a network map
|
||||
// changes the available routes or when a selection is applied — the peer
|
||||
// status itself only records actively-routed (chosen) networks, so without
|
||||
// this bump a candidate route appearing/disappearing would never reach the UI.
|
||||
func (d *Status) BumpNetworksRevision() {
|
||||
d.networksRevision.Add(1)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// GetNetworksRevision returns the current routed-networks revision, surfaced in
|
||||
// the status snapshot so the UI can detect route/selection changes (see
|
||||
// BumpNetworksRevision).
|
||||
func (d *Status) GetNetworksRevision() uint64 {
|
||||
return d.networksRevision.Load()
|
||||
}
|
||||
|
||||
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -90,11 +90,12 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_IgnoresOfflinePeers documents that peers
|
||||
// moved into the offline slice via ReplaceOfflinePeers are intentionally
|
||||
// not resolvable by IP: only active peers can carry traffic, so callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) treat them as unknown.
|
||||
func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
|
||||
// been moved into the offline slice via ReplaceOfflinePeers. Callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
|
||||
// rather than unknown — otherwise authentication / DNS filtering treats
|
||||
// known-but-offline peers as foreign IPs.
|
||||
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
@@ -102,31 +103,13 @@ func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
|
||||
})
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.False(ok, "offline peer must not resolve by IPv4 tunnel address")
|
||||
state, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::20")
|
||||
req.False(ok, "offline peer must not resolve by IPv6 tunnel address")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_RemovedPeer verifies RemovePeer drops the
|
||||
// IP index entries for both address families.
|
||||
func TestStatus_PeerStateByIP_RemovedPeer(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "active peer must resolve before removal")
|
||||
|
||||
req.NoError(status.RemovePeer("pk-1"))
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.10")
|
||||
req.False(ok, "removed peer must not resolve by IPv4 tunnel address")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::1")
|
||||
req.False(ok, "removed peer must not resolve by IPv6 tunnel address")
|
||||
state, ok = status.PeerStateByIP("fd00::20")
|
||||
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
@@ -314,3 +297,39 @@ func TestGetFullStatus(t *testing.T) {
|
||||
assert.Equal(t, signalState, fullStatus.SignalState, "signal status should be equal")
|
||||
assert.ElementsMatch(t, []State{peerState1, peerState2}, fullStatus.Peers, "peers states should match")
|
||||
}
|
||||
|
||||
// notified reports whether a state-change tick is pending on ch, draining it.
|
||||
func notified(ch <-chan struct{}) bool {
|
||||
select {
|
||||
case <-ch:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkServerStateDoesNotNotifyWhenUnchanged(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
_, ch := status.SubscribeToStateChanges()
|
||||
|
||||
// First transition is a real change and must notify.
|
||||
status.MarkManagementConnected()
|
||||
require.True(t, notified(ch), "first connect should notify")
|
||||
|
||||
// Re-marking the same state must not notify again.
|
||||
status.MarkManagementConnected()
|
||||
assert.False(t, notified(ch), "redundant connect should not notify")
|
||||
|
||||
// Same for signal.
|
||||
status.MarkSignalConnected()
|
||||
require.True(t, notified(ch), "first signal connect should notify")
|
||||
status.MarkSignalConnected()
|
||||
assert.False(t, notified(ch), "redundant signal connect should not notify")
|
||||
|
||||
// A genuine change (disconnect with an error) notifies again.
|
||||
err := errors.New("boom")
|
||||
status.MarkManagementDisconnected(err)
|
||||
require.True(t, notified(ch), "disconnect should notify")
|
||||
status.MarkManagementDisconnected(err)
|
||||
assert.False(t, notified(ch), "redundant disconnect should not notify")
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -58,10 +57,6 @@ var DefaultInterfaceBlacklist = []string{
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
// loadMDMPolicy is the package-level indirection used by apply() to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
@@ -179,23 +174,6 @@ type Config struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// policy is the MDM policy that produced the currently-set values for
|
||||
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
|
||||
// and reset on every apply() invocation. Never persisted to disk.
|
||||
// Callers query enforcement state via Policy() and the mdm.Policy API
|
||||
// (HasKey, ManagedKeys, IsEmpty).
|
||||
policy *mdm.Policy `json:"-"`
|
||||
}
|
||||
|
||||
// Policy returns the MDM policy applied to this Config. Returns a non-nil
|
||||
// empty Policy when MDM enforcement is inactive; callers can always invoke
|
||||
// HasKey / ManagedKeys / IsEmpty without a nil check.
|
||||
func (config *Config) Policy() *mdm.Policy {
|
||||
if config == nil || config.policy == nil {
|
||||
return mdm.NewPolicy(nil)
|
||||
}
|
||||
return config.policy
|
||||
}
|
||||
|
||||
var ConfigDirOverride string
|
||||
@@ -634,93 +612,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
// MDM is the last override layer: any key present in the policy
|
||||
// supersedes defaults, on-disk config, env vars and CLI input.
|
||||
config.applyMDMPolicy(loadMDMPolicy())
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// applyMDMPolicy overlays MDM-supplied values on top of the resolved Config.
|
||||
// The provided Policy is also stored on the Config so callers can later query
|
||||
// which fields are enforced. Invalid values (e.g. malformed URLs) are logged
|
||||
// and skipped to avoid bricking the client; the field keeps its previous
|
||||
// resolved value but is still marked as managed (Policy.HasKey returns true
|
||||
// for the key, so per-field rejection of user writes still applies).
|
||||
func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
|
||||
config.policy = policy
|
||||
if policy.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper: log the application of a single MDM-managed key. Values for
|
||||
// keys in mdm.SecretKeys are redacted.
|
||||
logApplied := func(key string, displayValue any) {
|
||||
if _, secret := mdm.SecretKeys[key]; secret {
|
||||
log.Infof("MDM override %s = ********** (secret)", key)
|
||||
return
|
||||
}
|
||||
log.Infof("MDM override %s = %v", key, displayValue)
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyManagementURL); ok {
|
||||
if u, err := parseURL("Management URL", v); err != nil {
|
||||
log.Warnf("MDM management URL %q invalid: %v; keeping previous value", v, err)
|
||||
} else {
|
||||
config.ManagementURL = u
|
||||
logApplied(mdm.KeyManagementURL, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyPreSharedKey); ok {
|
||||
// Defensive: refuse the redaction mask in case it round-tripped
|
||||
// through a manifest by mistake.
|
||||
if !isPreSharedKeyHidden(&v) {
|
||||
config.PreSharedKey = v
|
||||
logApplied(mdm.KeyPreSharedKey, "")
|
||||
}
|
||||
}
|
||||
|
||||
// applyBool collapses the per-key "read + set + log" boilerplate
|
||||
// for every plain bool MDM key into a single helper. Keeps the
|
||||
// outer function's cognitive complexity below SonarCube's
|
||||
// threshold; functional behaviour is identical to the inlined
|
||||
// branches it replaces.
|
||||
applyBool := func(key string, setter func(bool)) {
|
||||
v, ok := policy.GetBool(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
setter(v)
|
||||
logApplied(key, v)
|
||||
}
|
||||
|
||||
applyBool(mdm.KeyAllowServerSSH, func(v bool) { bv := v; config.ServerSSHAllowed = &bv })
|
||||
applyBool(mdm.KeyDisableClientRoutes, func(v bool) { config.DisableClientRoutes = v })
|
||||
applyBool(mdm.KeyDisableServerRoutes, func(v bool) { config.DisableServerRoutes = v })
|
||||
applyBool(mdm.KeyBlockInbound, func(v bool) { config.BlockInbound = v })
|
||||
applyBool(mdm.KeyDisableAutoConnect, func(v bool) { config.DisableAutoConnect = v })
|
||||
applyBool(mdm.KeyRosenpassEnabled, func(v bool) { config.RosenpassEnabled = v })
|
||||
applyBool(mdm.KeyRosenpassPermissive, func(v bool) { config.RosenpassPermissive = v })
|
||||
|
||||
if v, ok := policy.GetInt(mdm.KeyWireguardPort); ok {
|
||||
// REG_DWORD is 32-bit; UDP port range is 1-65535. Clamp at the
|
||||
// upper bound and reject obviously-invalid values to avoid the
|
||||
// engine binding to an unusable port if the admin pushes garbage.
|
||||
if v >= 1 && v <= 65535 {
|
||||
config.WgPort = int(v)
|
||||
logApplied(mdm.KeyWireguardPort, v)
|
||||
} else {
|
||||
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseURL parses and validates the URL for the named service. The URL
|
||||
// must use the http or https scheme; if no port is present, ":443" is
|
||||
// appended for https or ":80" for http. The serviceName parameter is
|
||||
// used to contextualise error messages. On success returns the parsed
|
||||
// *url.URL; on failure returns a non-nil error.
|
||||
// parseURL parses and validates a service URL
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
|
||||
// apply() observes the supplied Policy. The original loader is restored at
|
||||
// test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.Empty(t, cfg.Policy().ManagedKeys())
|
||||
|
||||
// Default management URL still resolves.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
}
|
||||
|
||||
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
|
||||
const mdmURL = "https://corp.mdm.example.com:443"
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyBlockInbound: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.DisableClientRoutes)
|
||||
assert.True(t, cfg.BlockInbound)
|
||||
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyBlockInbound))
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyAllowServerSSH))
|
||||
}
|
||||
|
||||
func TestApply_MDMBeatsCLIInput(t *testing.T) {
|
||||
const mdmURL = "https://mdm.example.com:443"
|
||||
const cliURL = "https://cli.example.com:443"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
ManagementURL: cliURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// MDM wins over CLI-supplied management URL.
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "not-a-url",
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Invalid MDM URL is logged and skipped: default URL stays in place
|
||||
// to keep the client functional.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
|
||||
// But the key is still considered MDM-managed (admin intent is to
|
||||
// enforce, daemon rejects user writes to this field — phase-1 scaffolding
|
||||
// reflects this by keeping Policy.HasKey true even on parse failure).
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Seed without MDM.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
_, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
DisableClientRoutes: boolPtr(false),
|
||||
RosenpassEnabled: boolPtr(false),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now enable MDM enforcement for these keys.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
|
||||
assert.True(t, cfg.RosenpassEnabled)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
|
||||
}
|
||||
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyPreSharedKey: maskSentinel,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Mask sentinel must not be persisted as the actual PSK.
|
||||
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
|
||||
// Key still marked managed so user writes are still rejected.
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -59,3 +60,22 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProfileState deletes the per-profile state file (which holds the
|
||||
// account email used for the SSO login hint and the UI display). Called after
|
||||
// a successful logout so a logged-out profile no longer shows a stale account
|
||||
// email. The state file only stores the email, so deleting it is equivalent to
|
||||
// clearing it; the next SSO login recreates it. A missing file is not an error.
|
||||
func (pm *ProfileManager) RemoveProfileState(profileName string) error {
|
||||
configDir, err := getConfigDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config directory: %w", err)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove profile state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
191
client/internal/routemanager/exit_node_selection_test.go
Normal file
191
client/internal/routemanager/exit_node_selection_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func newExitNodeTestManager() *DefaultManager {
|
||||
return &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
}
|
||||
|
||||
func exitRoute(netID, peer string, skipAutoApply bool) *route.Route {
|
||||
return &route.Route{
|
||||
NetID: route.NetID(netID),
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
Peer: peer,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPickPreferredExitNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
info exitNodeInfo
|
||||
want route.NetID
|
||||
}{
|
||||
{
|
||||
name: "persisted user selection wins over management",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
userSelected: []route.NetID{"b"},
|
||||
selectedByManagement: []route.NetID{"a"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "multiple user-selected self-heal to deterministic min",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
userSelected: []route.NetID{"c", "a"},
|
||||
},
|
||||
want: "a",
|
||||
},
|
||||
{
|
||||
name: "explicit opt-out keeps none",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b"},
|
||||
userDeselected: []route.NetID{"a", "b"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "fresh defaults to management auto-apply pick",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
selectedByManagement: []route.NetID{"b"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "no user pick and no management auto-apply selects none",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"c", "a", "b"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "user-deselect does not block a management auto-apply sibling",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b"},
|
||||
userDeselected: []route.NetID{"a"},
|
||||
selectedByManagement: []route.NetID{"b"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, pickPreferredExitNode(tt.info), "preferred exit node")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceSingleExitNode(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
all := []route.NetID{"a", "b", "c"}
|
||||
|
||||
m.enforceSingleExitNode("b", all)
|
||||
assert.False(t, m.routeSelector.IsSelected("a"), "a should be deselected")
|
||||
assert.True(t, m.routeSelector.IsSelected("b"), "b should be the only selected exit node")
|
||||
assert.False(t, m.routeSelector.IsSelected("c"), "c should be deselected")
|
||||
|
||||
// Switching the preferred node moves the single selection.
|
||||
m.enforceSingleExitNode("c", all)
|
||||
assert.False(t, m.routeSelector.IsSelected("a"), "a stays deselected")
|
||||
assert.False(t, m.routeSelector.IsSelected("b"), "b should now be deselected")
|
||||
assert.True(t, m.routeSelector.IsSelected("c"), "c should now be selected")
|
||||
|
||||
// Empty preferred turns every exit node off.
|
||||
m.enforceSingleExitNode("", all)
|
||||
for _, id := range all {
|
||||
assert.False(t, m.routeSelector.IsSelected(id), "no exit node should be selected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceSingleExitNode_RespectsDeselectAll(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
|
||||
m.enforceSingleExitNode("b", []route.NetID{"a", "b"})
|
||||
|
||||
assert.True(t, m.routeSelector.IsDeselectAllActive(), "global deselect-all must stay in effect")
|
||||
assert.False(t, m.routeSelector.IsSelected("b"), "no exit node should be forced on while deselect-all is set")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_FreshSelectsOne(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
"lan|192.168.1.0/24": {{NetID: "lan", Network: netip.MustParsePrefix("192.168.1.0/24"), Peer: "p3"}},
|
||||
"exitC|0.0.0.0/0": {exitRoute("exitC", "p4", false)},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
// Exactly one exit node (the deterministic first) is selected.
|
||||
assert.True(t, m.routeSelector.IsSelected("exitA"), "exitA is the deterministic default")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "exitB must not also be selected")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitC"), "exitC must not also be selected")
|
||||
// Non-exit routes are left at their default-on state.
|
||||
assert.True(t, m.routeSelector.IsSelected("lan"), "non-exit route selection is untouched")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_HonorsPersistedPick(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
}
|
||||
all := []route.NetID{"exitA", "exitB"}
|
||||
|
||||
// Simulate the state the runtime select path leaves behind: exactly one
|
||||
// exit node explicitly selected, its sibling deselected.
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exitB"}, true, all))
|
||||
require.NoError(t, m.routeSelector.DeselectRoutes([]route.NetID{"exitA"}, all))
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.True(t, m.routeSelector.IsSelected("exitB"), "persisted pick must stay selected")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "the other exit node stays deselected")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_OptOutKeepsNone(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
}
|
||||
all := []route.NetID{"exitA", "exitB"}
|
||||
|
||||
// User deselected exit nodes and selected none.
|
||||
require.NoError(t, m.routeSelector.DeselectRoutes(all, all))
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "opt-out keeps exitA off")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "opt-out keeps exitB off")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_NoAutoApplySelectsNone(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
// SkipAutoApply=true: management offers the exit nodes but doesn't request
|
||||
// auto-activation, so none should be selected until the user picks one.
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", true)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", true)},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "no auto-apply keeps exitA off")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "no auto-apply keeps exitB off")
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -439,6 +440,11 @@ func (m *DefaultManager) UpdateRoutes(
|
||||
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||
// A new network map can add or drop route/exit-node candidates without
|
||||
// touching any peer's chosen-route state, so the peer status alone
|
||||
// wouldn't notify SubscribeStatus subscribers. Bump the revision so the
|
||||
// UI re-fetches ListNetworks.
|
||||
m.statusRecorder.BumpNetworksRevision()
|
||||
}
|
||||
m.clientRoutes = clientRoutes
|
||||
|
||||
@@ -579,6 +585,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
// A selection change flips Network.selected without altering the candidate
|
||||
// set, so bump the revision to push the new state to the UI.
|
||||
m.statusRecorder.BumpNetworksRevision()
|
||||
}
|
||||
|
||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||
@@ -698,7 +708,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
return ips
|
||||
}
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
// updateRouteSelectorFromManagement reconciles exit-node selection on every
|
||||
// network map: it keeps at most one exit node selected — the user's persisted
|
||||
// pick, else whatever management marks for auto-apply (SkipAutoApply=false),
|
||||
// else none. We never auto-activate an exit node the map doesn't request; it
|
||||
// stays off until the user picks it. Exit nodes are mutually exclusive, but the
|
||||
// RouteSelector stores routes with default-on semantics, so without this every
|
||||
// available exit node would report selected at once.
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
@@ -707,13 +723,14 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
info := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(info.allIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
m.updateExitNodeSelections(exitNodeInfo)
|
||||
m.logExitNodeUpdate(exitNodeInfo)
|
||||
preferred := pickPreferredExitNode(info)
|
||||
m.enforceSingleExitNode(preferred, info.allIDs)
|
||||
m.logExitNodeUpdate(info, preferred)
|
||||
}
|
||||
|
||||
type exitNodeInfo struct {
|
||||
@@ -723,6 +740,10 @@ type exitNodeInfo struct {
|
||||
userDeselected []route.NetID
|
||||
}
|
||||
|
||||
// collectExitNodeInfo categorises the available exit nodes by their persisted
|
||||
// selection state. It keys on the base (v4) NetID and skips the synthesized
|
||||
// "-v6" partner, which inherits its base's selection through the RouteSelector
|
||||
// — counting it separately would double-count the pair.
|
||||
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
|
||||
var info exitNodeInfo
|
||||
|
||||
@@ -732,6 +753,9 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
||||
}
|
||||
|
||||
netID := haID.NetID()
|
||||
if strings.HasSuffix(string(netID), route.V6ExitSuffix) {
|
||||
continue
|
||||
}
|
||||
info.allIDs = append(info.allIDs, netID)
|
||||
|
||||
if m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||
@@ -768,45 +792,69 @@ func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID r
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
|
||||
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
|
||||
m.deselectExitNodes(routesToDeselect)
|
||||
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
|
||||
// pickPreferredExitNode chooses the single exit node to keep selected. In order:
|
||||
// - a persisted user selection wins (deterministic if several survive from
|
||||
// legacy state, so the set self-heals down to one);
|
||||
// - otherwise activate only what management marks for auto-apply
|
||||
// (SkipAutoApply=false); the lexicographically first if it marks several.
|
||||
//
|
||||
// Returns "" when neither holds — we never force an arbitrary exit node on. A
|
||||
// route the map doesn't auto-apply stays off until the user selects it.
|
||||
// info.userDeselected is informational only: an explicit deselect simply keeps
|
||||
// that route out of both lists above, so it can't be picked.
|
||||
func pickPreferredExitNode(info exitNodeInfo) route.NetID {
|
||||
if len(info.userSelected) > 0 {
|
||||
return minNetID(info.userSelected)
|
||||
}
|
||||
if len(info.selectedByManagement) > 0 {
|
||||
return minNetID(info.selectedByManagement)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
|
||||
var routesToDeselect []route.NetID
|
||||
for _, netID := range allIDs {
|
||||
if !m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||
routesToDeselect = append(routesToDeselect, netID)
|
||||
// enforceSingleExitNode makes preferred the only selected exit node: every other
|
||||
// available exit node is deselected and preferred (if any) is selected, without
|
||||
// disturbing non-exit route selections. A global deselect-all is left untouched
|
||||
// so the user's "all off" stays in effect.
|
||||
func (m *DefaultManager) enforceSingleExitNode(preferred route.NetID, allIDs []route.NetID) {
|
||||
if m.routeSelector.IsDeselectAllActive() {
|
||||
return
|
||||
}
|
||||
|
||||
others := make([]route.NetID, 0, len(allIDs))
|
||||
for _, id := range allIDs {
|
||||
if id != preferred {
|
||||
others = append(others, id)
|
||||
}
|
||||
}
|
||||
return routesToDeselect
|
||||
}
|
||||
|
||||
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
|
||||
if len(routesToDeselect) == 0 {
|
||||
return
|
||||
if len(others) > 0 {
|
||||
if err := m.routeSelector.DeselectRoutes(others, allIDs); err != nil {
|
||||
log.Warnf("deselect other exit nodes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to deselect exit nodes: %v", err)
|
||||
if preferred != "" {
|
||||
if err := m.routeSelector.SelectRoutes([]route.NetID{preferred}, true, allIDs); err != nil {
|
||||
log.Warnf("select preferred exit node %q: %v", preferred, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
|
||||
if len(selectedByManagement) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to select exit nodes: %v", err)
|
||||
}
|
||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo, preferred route.NetID) {
|
||||
log.Debugf("Exit node selection: %d available, preferred=%q (%d user-selected, %d user-deselected, %d management-selected)",
|
||||
len(info.allIDs), preferred, len(info.userSelected), len(info.userDeselected), len(info.selectedByManagement))
|
||||
}
|
||||
|
||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
|
||||
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
|
||||
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
|
||||
// minNetID returns the lexicographically smallest NetID, for a deterministic
|
||||
// default pick that stays stable across restarts.
|
||||
func minNetID(ids []route.NetID) route.NetID {
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
best := ids[0]
|
||||
for _, id := range ids[1:] {
|
||||
if id < best {
|
||||
best = id
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
@@ -3,24 +3,79 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files in this package compile.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedVPNint = "utun100"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedExternalInt = "lo0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedInternalInt = "lo0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -67,3 +122,122 @@ func TestBits(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// dialer is shared by the per-platform routing test cases. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files compile on every platform.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && !ios && privileged
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
@@ -26,6 +26,11 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func TestAddVPNRoute(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -510,3 +515,125 @@ func setupTestEnv(t *testing.T) {
|
||||
// unique route in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
}
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
//go:build !android && privileged
|
||||
//go:build !android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
@@ -15,6 +18,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
var expectedExternalInt = "dummyext0"
|
||||
var expectedInternalInt = "dummyint0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
@@ -26,6 +33,62 @@ func init() {
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files in this package compile.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedVPNint = "wgtest0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedExternalInt = "dummyext0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedInternalInt = "dummyint0"
|
||||
@@ -1,83 +0,0 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
|
||||
// per-platform init() appenders) consume these; they live here so the unprivileged
|
||||
// BSD/darwin test files compile without the privileged build tag.
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
type testCase struct {
|
||||
name string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package systemops
|
||||
|
||||
@@ -20,6 +20,63 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
nbnet.Init()
|
||||
for _, tc := range testCases {
|
||||
@@ -45,6 +102,16 @@ func TestRouting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
|
||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android && privileged
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
|
||||
@@ -8,14 +8,11 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
||||
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -132,6 +132,16 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
return rs.isSelectedLocked(routeID)
|
||||
}
|
||||
|
||||
// IsDeselectAllActive reports whether the global "deselect all" flag is set,
|
||||
// i.e. the user disabled every route. Callers enforcing per-route invariants
|
||||
// (e.g. single exit node) should leave the selection untouched when it is.
|
||||
func (rs *RouteSelector) IsDeselectAllActive() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -825,3 +825,31 @@ func TestRouteSelector_ComplexScenarios(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouteSelector_EnableExitNodeKeepsOtherRoutes is a regression test for the
|
||||
// tray exit-node toggle disabling every non-exit routed network. The tray used
|
||||
// to Select an exit node with append=false, which the RouteSelector treats as
|
||||
// "drop the whole current selection" (default-on semantics) — so enabling an
|
||||
// exit node also turned off every LAN/route the user had on. The fix sends
|
||||
// append=true and lets the daemon's SelectNetworks handler deselect only the
|
||||
// sibling exit nodes. This test models that handler sequence against the
|
||||
// selector: SelectRoutes(exit, append=true) followed by DeselectRoutes(other
|
||||
// exit nodes) must leave non-exit routes untouched.
|
||||
func TestRouteSelector_EnableExitNodeKeepsOtherRoutes(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
all := []route.NetID{"exitA", "exitB", "lan1", "lan2"}
|
||||
|
||||
// User has two LAN routes on (default-on: nothing deselected => all selected).
|
||||
require.True(t, rs.IsSelected("lan1"))
|
||||
require.True(t, rs.IsSelected("lan2"))
|
||||
|
||||
// Tray enables exitA: SelectNetworks handler does SelectRoutes(append=true)
|
||||
// then deselects sibling exit nodes (exitB), never the LAN routes.
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exitA"}, true, all))
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exitB"}, all))
|
||||
|
||||
assert.True(t, rs.IsSelected("exitA"), "selected exit node stays on")
|
||||
assert.False(t, rs.IsSelected("exitB"), "sibling exit node is deselected")
|
||||
assert.True(t, rs.IsSelected("lan1"), "non-exit route must stay selected")
|
||||
assert.True(t, rs.IsSelected("lan2"), "non-exit route must stay selected")
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type StatusType string
|
||||
@@ -33,17 +36,37 @@ func CtxGetState(ctx context.Context) *contextState {
|
||||
}
|
||||
|
||||
type contextState struct {
|
||||
err error
|
||||
status StatusType
|
||||
mutex sync.Mutex
|
||||
err error
|
||||
status StatusType
|
||||
mutex sync.Mutex
|
||||
onChange func()
|
||||
}
|
||||
|
||||
// SetOnChange installs a callback fired after every successful Set. Used by
|
||||
// the daemon to wire the status recorder's notifyStateChange so any
|
||||
// state.Set in the connect/login paths pushes a fresh snapshot to
|
||||
// SubscribeStatus subscribers without each callsite having to opt in.
|
||||
// The callback runs outside the contextState mutex to avoid a lock-order
|
||||
// dependency with the recorder's stateChangeMux.
|
||||
func (c *contextState) SetOnChange(fn func()) {
|
||||
c.mutex.Lock()
|
||||
c.onChange = fn
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *contextState) Set(update StatusType) {
|
||||
if _, file, line, ok := runtime.Caller(1); ok {
|
||||
log.Infof("--- state.Set(%s) from %s:%d", update, file, line)
|
||||
}
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.status = update
|
||||
c.err = nil
|
||||
cb := c.onChange
|
||||
c.mutex.Unlock()
|
||||
|
||||
if cb != nil {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *contextState) Status() (StatusType, error) {
|
||||
@@ -57,6 +80,17 @@ func (c *contextState) Status() (StatusType, error) {
|
||||
return c.status, nil
|
||||
}
|
||||
|
||||
// CurrentStatus returns the last status set via Set, ignoring any wrapped
|
||||
// error. Use when the status is needed for reporting purposes (e.g. the
|
||||
// status snapshot stream) and a transient wrapped error from a retry loop
|
||||
// shouldn't blank out the underlying status.
|
||||
func (c *contextState) CurrentStatus() StatusType {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *contextState) Wrap(err error) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package mdm
|
||||
|
||||
import "strings"
|
||||
|
||||
// allKeys is the set of recognised MDM keys. Unknown keys in a managed
|
||||
// configuration are ignored but logged. Lives in this build-tagged file
|
||||
// (windows || darwin) because only desktop loaders need the
|
||||
// canonicalisation table that consumes it; including it unconditionally
|
||||
// would trigger the `unused` golangci-lint check on platforms that
|
||||
// don't import canonical_loaders.go.
|
||||
var allKeys = []string{
|
||||
KeyManagementURL,
|
||||
KeyDisableUpdateSettings,
|
||||
KeyDisableProfiles,
|
||||
KeyDisableNetworks,
|
||||
KeyDisableClientRoutes,
|
||||
KeyDisableServerRoutes,
|
||||
KeyBlockInbound,
|
||||
KeyDisableMetricsCollection,
|
||||
KeyAllowServerSSH,
|
||||
KeyDisableAutoConnect,
|
||||
KeyPreSharedKey,
|
||||
KeyRosenpassEnabled,
|
||||
KeyRosenpassPermissive,
|
||||
KeyWireguardPort,
|
||||
KeySplitTunnelMode,
|
||||
KeySplitTunnelApps,
|
||||
}
|
||||
|
||||
// canonicalKey maps the lowercase form of a managed-config value name to
|
||||
// its canonical mdm.Key* form. Admins commonly write PascalCase value
|
||||
// names in ADMX / Group Policy ("ManagementURL"); the iOS/AppConfig and
|
||||
// macOS plist conventions are camelCase ("managementURL"); both must
|
||||
// resolve to the same Policy lookup.
|
||||
//
|
||||
// Lives in a desktop-loader-only file (build tag `windows || darwin`)
|
||||
// because no other build path consumes it. Linux / FreeBSD / mobile
|
||||
// builds don't ship a platform loader that reads arbitrary-case key
|
||||
// names, so they don't need the canonicalisation table — and including
|
||||
// the var unconditionally would trigger the `unused` golangci-lint
|
||||
// check on those platforms.
|
||||
var canonicalKey = func() map[string]string {
|
||||
m := make(map[string]string, len(allKeys))
|
||||
for _, k := range allKeys {
|
||||
m[strings.ToLower(k)] = k
|
||||
}
|
||||
return m
|
||||
}()
|
||||
@@ -1,247 +0,0 @@
|
||||
// Package mdm reads MDM-managed configuration from platform-native sources
|
||||
// (plist on macOS, registry on Windows, UserDefaults on iOS,
|
||||
// RestrictionsManager on Android). The returned Policy is consumed by
|
||||
// profilemanager.Config.apply() as the highest-priority override layer.
|
||||
//
|
||||
// An empty Policy (no source present, or source present with zero keys)
|
||||
// means no MDM enforcement is active and the client behaves as if the
|
||||
// feature did not exist.
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Well-known policy keys. Names mirror the corresponding ConfigInput Go field
|
||||
// names (lowerCamelCase) so the daemon can map a Policy key directly to a
|
||||
// configuration field.
|
||||
const (
|
||||
KeyManagementURL = "managementURL"
|
||||
KeyDisableUpdateSettings = "disableUpdateSettings"
|
||||
KeyDisableProfiles = "disableProfiles"
|
||||
KeyDisableNetworks = "disableNetworks"
|
||||
KeyDisableClientRoutes = "disableClientRoutes"
|
||||
KeyDisableServerRoutes = "disableServerRoutes"
|
||||
KeyBlockInbound = "blockInbound"
|
||||
KeyDisableMetricsCollection = "disableMetricsCollection"
|
||||
KeyAllowServerSSH = "allowServerSSH"
|
||||
KeyDisableAutoConnect = "disableAutoConnect"
|
||||
KeyPreSharedKey = "preSharedKey"
|
||||
KeyRosenpassEnabled = "rosenpassEnabled"
|
||||
KeyRosenpassPermissive = "rosenpassPermissive"
|
||||
KeyWireguardPort = "wireguardPort"
|
||||
|
||||
// Split tunnel is modeled as a single conceptual policy with two
|
||||
// registry/plist values. KeySplitTunnelMode is the discriminator
|
||||
// ("allow" or "disallow"); KeySplitTunnelApps is a comma-separated
|
||||
// list of package names. The values are mutually exclusive by
|
||||
// construction — only one mode can be set at a time.
|
||||
KeySplitTunnelMode = "splitTunnelMode"
|
||||
KeySplitTunnelApps = "splitTunnelApps"
|
||||
)
|
||||
|
||||
// Split-tunnel mode literals (KeySplitTunnelMode values).
|
||||
const (
|
||||
SplitTunnelModeAllow = "allow"
|
||||
SplitTunnelModeDisallow = "disallow"
|
||||
)
|
||||
|
||||
// SecretKeys lists keys whose values must be redacted in logs.
|
||||
var SecretKeys = map[string]struct{}{
|
||||
KeyPreSharedKey: {},
|
||||
}
|
||||
|
||||
// boolStringLiterals enumerates the textual boolean encodings the
|
||||
// platform loaders may produce (Windows REG_SZ "true", iOS / Android
|
||||
// managed-config booleans-as-strings, etc.). Lookup keeps GetBool flat
|
||||
// (no nested switch on the string case).
|
||||
var boolStringLiterals = map[string]bool{
|
||||
"true": true,
|
||||
"1": true,
|
||||
"yes": true,
|
||||
"false": false,
|
||||
"0": false,
|
||||
"no": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
// NewPolicy constructs a Policy from a key→value map. Pass nil or an
|
||||
// empty map to construct an empty (no-enforcement) Policy. The returned
|
||||
// *Policy is always non-nil.
|
||||
func NewPolicy(values map[string]any) *Policy {
|
||||
if values == nil {
|
||||
values = map[string]any{}
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// LoadPolicy reads the platform-native MDM configuration. Returns an
|
||||
// empty (but non-nil) Policy when no source is present, the source is
|
||||
// empty, or the platform is unsupported.
|
||||
//
|
||||
// Diagnostic logging differentiates the three states:
|
||||
// - source absent / unsupported platform: trace log only
|
||||
// - source present, zero keys: info "MDM enrolled (no managed keys)"
|
||||
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
|
||||
func LoadPolicy() *Policy {
|
||||
values, err := loadPlatformPolicy()
|
||||
if err != nil {
|
||||
log.Tracef("MDM policy load: %v", err)
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if values == nil {
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if len(values) == 0 {
|
||||
log.Info("MDM enrolled (no managed keys)")
|
||||
} else {
|
||||
log.Infof("MDM enrolled with %d managed key(s): %v", len(values), sortedKeys(values))
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// IsEmpty reports whether the Policy has no managed keys.
|
||||
func (p *Policy) IsEmpty() bool {
|
||||
return p == nil || len(p.values) == 0
|
||||
}
|
||||
|
||||
// HasKey reports whether the given key is MDM-managed.
|
||||
func (p *Policy) HasKey(key string) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := p.values[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ManagedKeys returns the sorted list of managed key names. Returns an empty
|
||||
// slice (not nil) on an empty Policy.
|
||||
func (p *Policy) ManagedKeys() []string {
|
||||
if p == nil {
|
||||
return []string{}
|
||||
}
|
||||
return sortedKeys(p.values)
|
||||
}
|
||||
|
||||
// GetString returns the managed value for key coerced to string, and whether
|
||||
// the key was set. A non-string value returns ("", false).
|
||||
func (p *Policy) GetString(key string) (string, bool) {
|
||||
if p == nil {
|
||||
return "", false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// GetBool returns the managed value for key coerced to bool, and whether the
|
||||
// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0".
|
||||
func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
if p == nil {
|
||||
return false, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
b, known := boolStringLiterals[t]
|
||||
return b, known
|
||||
case int:
|
||||
return t != 0, true
|
||||
case int64:
|
||||
return t != 0, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetInt returns the managed value for key as int64, and whether the key
|
||||
// was set. Accepts native int / int64 (as produced by the Windows registry
|
||||
// loader for REG_DWORD/REG_QWORD) and numeric strings (decimal).
|
||||
func (p *Policy) GetInt(key string) (int64, bool) {
|
||||
if p == nil {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t, true
|
||||
case int:
|
||||
return int64(t), true
|
||||
case int32:
|
||||
return int64(t), true
|
||||
case uint64:
|
||||
return int64(t), true
|
||||
case float64:
|
||||
return int64(t), true
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetStringSlice returns the managed value for key as []string, and whether
|
||||
// the key was set. Accepts []string, []any (of strings), and a single string
|
||||
// (treated as a one-element list).
|
||||
func (p *Policy) GetStringSlice(key string) ([]string, bool) {
|
||||
if p == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case []string:
|
||||
return append([]string(nil), t...), true
|
||||
case []any:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, item := range t {
|
||||
s, ok := item.(string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out = append(out, s)
|
||||
}
|
||||
return out, true
|
||||
case string:
|
||||
return []string{t}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// sortedKeys returns the keys of m as a deterministic, lexicographically
|
||||
// sorted slice. Used internally by Policy.ManagedKeys and LoadPolicy's
|
||||
// diagnostic log line so callers see a stable key order across runs
|
||||
// regardless of Go's randomised map iteration.
|
||||
func sortedKeys(m map[string]any) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
// policyPlistPath is the well-known location where macOS writes the
|
||||
// device-level mandatory MDM payload for NetBird. The path is fixed by
|
||||
// Apple convention: when an MDM provider (Jamf / Kandji / Mosyle /
|
||||
// Intune for Mac / Workspace ONE) pushes a Configuration Profile that
|
||||
// contains a com.apple.ManagedClient.preferences payload targeting the
|
||||
// bundle id io.netbird.client, the OS materializes the payload here.
|
||||
//
|
||||
// Read-only — only the OS (root) is supposed to write this file. The
|
||||
// loader sanity-checks the file mode and refuses to honour a world-
|
||||
// writable plist, as a defense against tampered installs.
|
||||
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. Returns:
|
||||
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
|
||||
// NetBird, or admin has not yet pushed a payload)
|
||||
// - (map, nil) with N entries when N managed values are present
|
||||
// (N may be 0 — empty plist still signals enrollment to the caller)
|
||||
// - (nil, err) on permission / parse / safety errors (including
|
||||
// refusal to read a world-writable plist)
|
||||
//
|
||||
// Top-level plist keys are canonicalised case-insensitively to the
|
||||
// package's internal mdm.Key* names; unknown keys are logged and
|
||||
// skipped so a stray entry in the payload does not block startup.
|
||||
// Native plist value types map naturally onto the Policy accessor
|
||||
// expectations (GetString / GetBool / GetInt / GetStringSlice).
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
f, err := os.Open(policyPlistPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// Not enrolled for NetBird. Caller treats nil as
|
||||
// "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyPlistPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close plist %s: %v", policyPlistPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat %s: %w", policyPlistPath, err)
|
||||
}
|
||||
// World-writable plist => tampered install. Refuse rather than
|
||||
// honour potentially attacker-controlled policy values.
|
||||
if info.Mode().Perm()&0o002 != 0 {
|
||||
return nil, fmt.Errorf("refusing to read world-writable MDM source %s (mode %o)",
|
||||
policyPlistPath, info.Mode().Perm())
|
||||
}
|
||||
|
||||
raw := make(map[string]any)
|
||||
if err := plist.NewDecoder(f).Decode(&raw); err != nil {
|
||||
return nil, fmt.Errorf("decode plist %s: %w", policyPlistPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(raw))
|
||||
for name, val := range raw {
|
||||
// macOS / AppConfig conventions both use camelCase for managed
|
||||
// preferences keys; canonicalize to the mdm.Key* form so a key
|
||||
// written as "ManagementURL" (PascalCase, rare on macOS but
|
||||
// possible if the admin reused an ADMX-style name) still
|
||||
// resolves.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown plist key %s: %s", policyPlistPath, name)
|
||||
continue
|
||||
}
|
||||
out[canonical] = val
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
|
||||
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
|
||||
// resulting dictionary in-process via a gomobile entry point that lands in
|
||||
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
|
||||
// builds and returns (nil, nil) — the platform-absent sentinel that
|
||||
// LoadPolicy in policy.go treats as "no MDM source present".
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !windows && !darwin && !ios && !android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy returns no policy on platforms without an MDM channel
|
||||
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
|
||||
// the feature did not exist. Returns (nil, nil) — the platform-absent
|
||||
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
|
||||
// source present"; an error here would just translate to the same
|
||||
// outcome with an extra log line.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPolicy_NilSafe(t *testing.T) {
|
||||
var p *Policy
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
|
||||
_, ok := p.GetString(KeyManagementURL)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_Empty(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
assert.True(t, p.HasKey(KeyDisableProfiles))
|
||||
assert.False(t, p.HasKey(KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://x",
|
||||
KeyAllowServerSSH: false,
|
||||
})
|
||||
got := p.ManagedKeys()
|
||||
assert.Equal(t, []string{KeyAllowServerSSH, KeyDisableProfiles, KeyManagementURL}, got)
|
||||
}
|
||||
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "https://corp.example.com", v)
|
||||
|
||||
_, ok = p.GetString(KeyDisableProfiles)
|
||||
assert.False(t, ok, "non-string value must not be reported as string")
|
||||
|
||||
_, ok = p.GetString(KeyPreSharedKey)
|
||||
assert.False(t, ok, "empty string treated as unset")
|
||||
|
||||
_, ok = p.GetString("nonexistent")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetBool(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw any
|
||||
want bool
|
||||
ok bool
|
||||
}{
|
||||
{"native true", true, true, true},
|
||||
{"native false", false, false, true},
|
||||
{"string true", "true", true, true},
|
||||
{"string false", "false", false, true},
|
||||
{"string 1", "1", true, true},
|
||||
{"string 0", "0", false, true},
|
||||
{"string yes", "yes", true, true},
|
||||
{"string no", "no", false, true},
|
||||
{"int nonzero", 1, true, true},
|
||||
{"int zero", 0, false, true},
|
||||
{"int64 nonzero", int64(2), true, true},
|
||||
{"int64 zero", int64(0), false, true},
|
||||
{"string garbage", "maybe", false, false},
|
||||
{"float unsupported", 1.0, false, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{KeyDisableProfiles: c.raw})
|
||||
got, ok := p.GetBool(KeyDisableProfiles)
|
||||
assert.Equal(t, c.ok, ok)
|
||||
if c.ok {
|
||||
assert.Equal(t, c.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
_, ok := NewPolicy(nil).GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetStringSlice(t *testing.T) {
|
||||
t.Run("native string slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []string{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("any slice of strings", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("single string lifts to one-element slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: "com.a",
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a"}, got)
|
||||
})
|
||||
|
||||
t.Run("mixed any slice rejected", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", 1},
|
||||
})
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
|
||||
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
|
||||
// degrade gracefully and never return nil.
|
||||
p := LoadPolicy()
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
// policyRegistryPath is the well-known MDM policy registry key for NetBird.
|
||||
// Admins push values here through Group Policy, Intune ADMX ingestion, an
|
||||
// Intune custom Registry CSP profile, or `reg add` during MSI deployment.
|
||||
// Listed in the project's docs/mdm/netbird.admx schema.
|
||||
const policyRegistryPath = `Software\Policies\NetBird`
|
||||
|
||||
// readRegistryValue reads a single value under policyRegistryPath and,
|
||||
// on success, stores the type-coerced result in out[canonical]. Type
|
||||
// coercion mirrors loadPlatformPolicy's documented mapping:
|
||||
// - REG_SZ / REG_EXPAND_SZ -> string (REG_EXPAND_SZ is expanded by the API)
|
||||
// - REG_DWORD / REG_QWORD -> int64
|
||||
// - REG_MULTI_SZ -> []string
|
||||
//
|
||||
// Unsupported value types and per-value read failures are logged at
|
||||
// warn level and skipped — one malformed value must not block the
|
||||
// surrounding loop. Extracted from loadPlatformPolicy to keep that
|
||||
// function's cognitive complexity in check.
|
||||
func readRegistryValue(k registry.Key, name, canonical string, out map[string]any) {
|
||||
_, valType, err := k.GetValue(name, nil)
|
||||
if err != nil {
|
||||
log.Warnf("MDM stat %s\\%s: %v", policyRegistryPath, name, err)
|
||||
return
|
||||
}
|
||||
switch valType {
|
||||
case registry.SZ, registry.EXPAND_SZ:
|
||||
if v, _, err := k.GetStringValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.DWORD, registry.QWORD:
|
||||
if v, _, err := k.GetIntegerValue(name); err == nil {
|
||||
// uint64 from the registry API; Policy.GetBool / GetInt
|
||||
// helpers consume int64, so narrow safely.
|
||||
out[canonical] = int64(v)
|
||||
} else {
|
||||
log.Warnf("MDM read int %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.MULTI_SZ:
|
||||
if v, _, err := k.GetStringsValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read multi-string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
default:
|
||||
log.Warnf("MDM ignoring unsupported registry value type %d at %s\\%s",
|
||||
valType, policyRegistryPath, name)
|
||||
}
|
||||
}
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the
|
||||
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
|
||||
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
|
||||
// - (map, nil) with N entries when N managed values are set (N may be 0)
|
||||
// - (nil, err) on open / enumerate registry errors
|
||||
//
|
||||
// Per-value type coercion + skip-on-error is delegated to
|
||||
// readRegistryValue. Unknown value names are logged and skipped so a
|
||||
// malformed deployment does not block startup.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if errors.Is(err, registry.ErrNotExist) {
|
||||
// Not enrolled. Caller treats nil as "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := k.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close registry key %s: %v", policyRegistryPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("enumerate values of %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(names))
|
||||
for _, name := range names {
|
||||
// Canonicalize the registry value name against the known MDM key
|
||||
// set so Policy.HasKey lookups (which use the canonical names)
|
||||
// succeed regardless of the casing used by the admin's ADMX or
|
||||
// `reg add` command.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown registry value %s\\%s", policyRegistryPath, name)
|
||||
continue
|
||||
}
|
||||
readRegistryValue(k, name, canonical, out)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultReloadInterval is the production cadence at which the desktop daemon
|
||||
// re-reads the OS-native MDM policy. Picked to balance responsiveness against
|
||||
// registry/plist I/O overhead. Mobile builds use OS-side notifications
|
||||
// instead, hence anticipating the ticker mechanism entirely.
|
||||
const DefaultReloadInterval = 1 * time.Minute
|
||||
|
||||
// policyLoader is the indirection through which the ticker reads the
|
||||
// OS-native policy, both for the initial observation and on every tick.
|
||||
// Production points it at LoadPolicy; tests in this package override it to
|
||||
// feed a scripted sequence of policies without touching the real OS store.
|
||||
var policyLoader = LoadPolicy
|
||||
|
||||
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
|
||||
// invokes the onChange callback (supplied to Run) whenever the observed
|
||||
// Policy diverges from the last observation (added / removed / changed
|
||||
// keys). Launch with Run from a goroutine; cancel the supplied context
|
||||
// to stop.
|
||||
type Ticker struct {
|
||||
interval time.Duration
|
||||
prev *Policy
|
||||
}
|
||||
|
||||
// NewTicker constructs a Ticker that will re-read the OS-native policy
|
||||
// every reloadInterval once Run is called.
|
||||
// The initial snapshot is populated by calling policyLoader at
|
||||
// construction time so the first tick only fires
|
||||
// onChange when the policy actually changed since boot — without
|
||||
// this baseline the first tick would report every currently-managed
|
||||
// key as "added" and trigger a spurious engine restart.
|
||||
func NewTicker(reloadInterval time.Duration) *Ticker {
|
||||
return &Ticker{
|
||||
interval: reloadInterval,
|
||||
prev: policyLoader(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run blocks until ctx is cancelled, polling the OS-native policy store at
|
||||
// the configured cadence and emitting log lines + onChange callback on
|
||||
// every observed diff. onChange must be non-nil.
|
||||
func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) error) {
|
||||
tk := time.NewTicker(t.interval)
|
||||
defer tk.Stop()
|
||||
log.Infof("MDM policy reload ticker started (interval=%s)", t.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("MDM policy reload ticker stopped")
|
||||
return
|
||||
case <-tk.C:
|
||||
curr := policyLoader()
|
||||
if policiesEqual(t.prev, curr) {
|
||||
continue
|
||||
}
|
||||
added, removed, changed := diffPolicies(t.prev, curr)
|
||||
log.Infof("MDM policy changed: added=%v removed=%v changed=%v",
|
||||
added, removed, changed)
|
||||
prev := t.prev
|
||||
if err := onChange(prev, curr); err != nil {
|
||||
log.Errorf("MDM policy change handler failed (retrying in 1 minute): %v", err)
|
||||
continue
|
||||
}
|
||||
t.prev = curr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policiesEqual reports whether two Policy instances carry the same
|
||||
// managed key set with identical values. Nil and empty policies
|
||||
// compare equal; one-nil/one-non-empty compare not equal; otherwise
|
||||
// the underlying values maps are compared with reflect.DeepEqual.
|
||||
func policiesEqual(a, b *Policy) bool {
|
||||
if a.IsEmpty() && b.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return reflect.DeepEqual(a.values, b.values)
|
||||
}
|
||||
|
||||
// diffPolicies returns the keys added in curr, removed from prev, and
|
||||
// whose values changed between prev and curr. Each slice is sorted
|
||||
// lexicographically for stable log output; value differences are
|
||||
// determined with reflect.DeepEqual.
|
||||
func diffPolicies(prev, curr *Policy) (added, removed, changed []string) {
|
||||
prevKVs := mapOf(prev)
|
||||
currKVs := mapOf(curr)
|
||||
for k := range currKVs {
|
||||
if _, ok := prevKVs[k]; !ok {
|
||||
added = append(added, k)
|
||||
} else if !reflect.DeepEqual(prevKVs[k], currKVs[k]) {
|
||||
changed = append(changed, k)
|
||||
}
|
||||
}
|
||||
for k := range prevKVs {
|
||||
if _, ok := currKVs[k]; !ok {
|
||||
removed = append(removed, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(added)
|
||||
sort.Strings(removed)
|
||||
sort.Strings(changed)
|
||||
return added, removed, changed
|
||||
}
|
||||
|
||||
// mapOf returns a (possibly empty, never nil) copy of the underlying
|
||||
// values map of a Policy so callers outside this package can compare
|
||||
// keys/values across the type boundary. Returns an empty map on nil p.
|
||||
func mapOf(p *Policy) map[string]any {
|
||||
if p == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(p.values))
|
||||
for k, v := range p.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testReloadInterval for speeding up the ticker cadence under `go test`
|
||||
const testReloadInterval = 1 * time.Second
|
||||
|
||||
// withPolicyLoader overrides the package-level policyLoader for the duration
|
||||
// of the test so the ticker observes a scripted policy instead of the real
|
||||
// OS-native store. The original loader is restored on cleanup.
|
||||
func withPolicyLoader(t *testing.T, fn func() *Policy) {
|
||||
t.Helper()
|
||||
prev := policyLoader
|
||||
policyLoader = fn
|
||||
t.Cleanup(func() { policyLoader = prev })
|
||||
}
|
||||
|
||||
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
current := NewPolicy(nil) // initial observation: empty (no enforcement)
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return current
|
||||
})
|
||||
|
||||
type change struct{ prev, curr *Policy }
|
||||
changes := make(chan change, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
require.Equal(t, testReloadInterval, tk.interval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(prev, curr *Policy) error {
|
||||
select {
|
||||
case changes <- change{prev, curr}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
// Stop Run and wait for it to exit before returning, so the policyLoader
|
||||
// restore in t.Cleanup can't race the ticker goroutine still reading it.
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Flip the OS-observed policy from empty to one managed key. The next
|
||||
// tick must detect the diff and invoke onChange.
|
||||
mu.Lock()
|
||||
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
mu.Unlock()
|
||||
|
||||
select {
|
||||
case c := <-changes:
|
||||
assert.True(t, c.prev.IsEmpty(), "prev should be the initial empty policy")
|
||||
assert.True(t, c.curr.HasKey(KeyManagementURL), "curr should carry the newly-pushed managed key")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("onChange not invoked within 5s; ticker should fire every 1s under test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
return NewPolicy(map[string]any{KeyBlockInbound: true})
|
||||
})
|
||||
|
||||
fired := make(chan struct{}, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(_, _ *Policy) error {
|
||||
select {
|
||||
case fired <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
|
||||
// diff guard must suppress the callback entirely.
|
||||
select {
|
||||
case <-fired:
|
||||
t.Fatal("onChange fired despite an unchanged policy")
|
||||
case <-time.After(2500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
@@ -32,9 +32,6 @@
|
||||
</File>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
|
||||
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||
<?endif ?>
|
||||
|
||||
<ServiceInstall
|
||||
Id="NetBirdService"
|
||||
@@ -62,6 +59,14 @@
|
||||
<Component Id="NetbirdAumidRegistry" Guid="*">
|
||||
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
|
||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
||||
<!-- Pre-seed the CLSID the Wails notifications service reads on
|
||||
first startup (notifications_windows.go:getGUID looks for
|
||||
the CustomActivator value under this key). Without this
|
||||
the service generates a fresh per-install UUID, which
|
||||
diverges from the ToastActivatorCLSID set on the Start
|
||||
Menu / Desktop shortcuts above and the COM activator
|
||||
never fires when a toast is clicked. -->
|
||||
<RegistryValue Name="CustomActivator" Type="string" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
</StandardDirectory>
|
||||
@@ -85,10 +90,40 @@
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" TerminateProcess="0" />
|
||||
|
||||
<!-- WebView2 evergreen runtime detection.
|
||||
Probe both the per-machine and per-user EdgeUpdate keys; if either
|
||||
reports a non-empty `pv` value the runtime is already installed
|
||||
and we skip the bootstrapper. -->
|
||||
<Property Id="WEBVIEW2_VERSION_HKLM">
|
||||
<RegistrySearch Id="WV2HKLM" Root="HKLM"
|
||||
Key="SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
|
||||
Name="pv" Type="raw" Bitness="always64" />
|
||||
</Property>
|
||||
<Property Id="WEBVIEW2_VERSION_HKCU">
|
||||
<RegistrySearch Id="WV2HKCU" Root="HKCU"
|
||||
Key="Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
|
||||
Name="pv" Type="raw" />
|
||||
</Property>
|
||||
|
||||
<!-- Embed the bootstrapper payload. Path is relative to the WiX
|
||||
working directory; sign-pipelines stages it next to client/
|
||||
via `wails3 generate webview2bootstrapper`. -->
|
||||
<Binary Id="WebView2Bootstrapper" SourceFile=".\client\MicrosoftEdgeWebview2Setup.exe" />
|
||||
|
||||
<CustomAction Id="InstallWebView2"
|
||||
BinaryRef="WebView2Bootstrapper"
|
||||
ExeCommand="/silent /install"
|
||||
Execute="deferred"
|
||||
Impersonate="no"
|
||||
Return="check" />
|
||||
|
||||
<InstallExecuteSequence>
|
||||
<Custom Action="InstallWebView2" Before="InstallFinalize"
|
||||
Condition="NOT WEBVIEW2_VERSION_HKLM AND NOT WEBVIEW2_VERSION_HKCU AND NOT REMOVE" />
|
||||
</InstallExecuteSequence>
|
||||
|
||||
<!-- Icons -->
|
||||
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />
|
||||
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\build\windows\icon.ico" />
|
||||
<Property Id="ARPPRODUCTICON" Value="NetbirdIcon" />
|
||||
|
||||
</Package>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,6 +24,12 @@ service DaemonService {
|
||||
// Status of the service.
|
||||
rpc Status(StatusRequest) returns (StatusResponse) {}
|
||||
|
||||
// SubscribeStatus pushes a fresh StatusResponse on connection state
|
||||
// changes (Connected / Disconnected / Connecting / address change /
|
||||
// peers list change). The first message on the stream is the current
|
||||
// snapshot, so a freshly-subscribed UI doesn't need to also call Status.
|
||||
rpc SubscribeStatus(StatusRequest) returns (stream StatusResponse) {}
|
||||
|
||||
// Down stops engine work in the daemon.
|
||||
rpc Down(DownRequest) returns (DownResponse) {}
|
||||
|
||||
@@ -109,6 +115,25 @@ service DaemonService {
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||
|
||||
// RequestExtendAuthSession initiates an SSO session-extension flow.
|
||||
// The daemon prepares a PKCE/device-code request against the IdP and
|
||||
// returns the verification URI; the UI is expected to open it. The flow
|
||||
// state is kept in the daemon until WaitExtendAuthSession completes it.
|
||||
rpc RequestExtendAuthSession(RequestExtendAuthSessionRequest) returns (RequestExtendAuthSessionResponse) {}
|
||||
|
||||
// WaitExtendAuthSession blocks until the user finishes the SSO step
|
||||
// started by RequestExtendAuthSession, then forwards the resulting JWT
|
||||
// to the management server's ExtendAuthSession RPC. Returns the new
|
||||
// session expiry deadline. The tunnel stays up the entire time.
|
||||
rpc WaitExtendAuthSession(WaitExtendAuthSessionRequest) returns (WaitExtendAuthSessionResponse) {}
|
||||
|
||||
// DismissSessionWarning records that the user clicked "Dismiss" on the
|
||||
// T-WarningLead interactive notification, suppressing the auto-opened
|
||||
// SessionAboutToExpire dialog that would otherwise fire at
|
||||
// T-FinalWarningLead for the current deadline. Idempotent and best-effort:
|
||||
// a missed call only means the fallback dialog will still appear.
|
||||
rpc DismissSessionWarning(DismissSessionWarningRequest) returns (DismissSessionWarningResponse) {}
|
||||
|
||||
// StartCPUProfile starts CPU profiling in the daemon
|
||||
rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {}
|
||||
|
||||
@@ -227,6 +252,12 @@ message UpRequest {
|
||||
optional string profileName = 1;
|
||||
optional string username = 2;
|
||||
reserved 3;
|
||||
// async instructs the daemon to start the connection attempt and return
|
||||
// immediately without waiting for the engine to become ready. Status updates
|
||||
// are delivered via the SubscribeStatus stream. When false (the default) the
|
||||
// RPC blocks until the engine is running or gives up, which is the behaviour
|
||||
// needed by the CLI.
|
||||
bool async = 4;
|
||||
}
|
||||
|
||||
message UpResponse {}
|
||||
@@ -244,6 +275,10 @@ message StatusResponse{
|
||||
FullStatus fullStatus = 2;
|
||||
// NetBird daemon version
|
||||
string daemonVersion = 3;
|
||||
// Absolute UTC instant at which the peer's SSO session expires.
|
||||
// Unset when the peer is not SSO-registered or login expiration is disabled.
|
||||
// The UI derives "warning active" from this value and its own clock.
|
||||
google.protobuf.Timestamp sessionExpiresAt = 4;
|
||||
}
|
||||
|
||||
message DownRequest {}
|
||||
@@ -314,13 +349,6 @@ message GetConfigResponse {
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
// mDMManagedFields lists the names of configuration keys whose value is
|
||||
// currently enforced by an MDM policy. Names match mdm.Key* constants
|
||||
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
|
||||
// render the corresponding inputs as read-only and display a "managed
|
||||
// by MDM" indicator.
|
||||
repeated string mDMManagedFields = 28;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -416,6 +444,12 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
|
||||
// networksRevision bumps whenever the set of routed networks (route and
|
||||
// exit-node candidates) or their selected state changes. The UI fingerprints
|
||||
// on it to know when to re-fetch ListNetworks via the push stream, instead
|
||||
// of polling on every status snapshot.
|
||||
uint64 networksRevision = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -740,15 +774,6 @@ message GetFeaturesResponse{
|
||||
bool disable_networks = 3;
|
||||
}
|
||||
|
||||
// MDMManagedFieldsViolation is attached as a gRPC error detail on a
|
||||
// FailedPrecondition status returned from SetConfig (and similar mutating
|
||||
// RPCs) when the caller tries to modify one or more MDM-enforced fields.
|
||||
// The fields list contains the offending key names; the entire request is
|
||||
// rejected (no partial apply).
|
||||
message MDMManagedFieldsViolation {
|
||||
repeated string fields = 1;
|
||||
}
|
||||
|
||||
message TriggerUpdateRequest {}
|
||||
|
||||
message TriggerUpdateResponse {
|
||||
@@ -816,6 +841,55 @@ message WaitJWTTokenResponse {
|
||||
int64 expiresIn = 3;
|
||||
}
|
||||
|
||||
// RequestExtendAuthSessionRequest kicks off the session-extension SSO flow.
|
||||
message RequestExtendAuthSessionRequest {
|
||||
// Optional OIDC login_hint (typically the user's email) to pre-fill the
|
||||
// IdP login form.
|
||||
optional string hint = 1;
|
||||
}
|
||||
|
||||
// RequestExtendAuthSessionResponse carries the verification URI the UI
|
||||
// should open in a browser. The daemon retains the flow state and resolves
|
||||
// it via WaitExtendAuthSession.
|
||||
message RequestExtendAuthSessionResponse {
|
||||
// verification URI for the user to open in the browser
|
||||
string verificationURI = 1;
|
||||
// complete verification URI (with embedded user code)
|
||||
string verificationURIComplete = 2;
|
||||
// user code to enter on verification URI (for device-code flows)
|
||||
string userCode = 3;
|
||||
// device code for matching the WaitExtendAuthSession call to this flow
|
||||
string deviceCode = 4;
|
||||
// expiration time in seconds for the device code / PKCE flow
|
||||
int64 expiresIn = 5;
|
||||
}
|
||||
|
||||
// WaitExtendAuthSessionRequest is sent by the UI after it opens the
|
||||
// verification URI. The daemon blocks on this call until the user
|
||||
// completes (or aborts) the SSO step.
|
||||
message WaitExtendAuthSessionRequest {
|
||||
// device code returned by RequestExtendAuthSession
|
||||
string deviceCode = 1;
|
||||
// user code for verification
|
||||
string userCode = 2;
|
||||
}
|
||||
|
||||
// WaitExtendAuthSessionResponse carries the refreshed deadline returned
|
||||
// by the management server. Unset when the management server reports the
|
||||
// peer is not eligible for session extension.
|
||||
message WaitExtendAuthSessionResponse {
|
||||
google.protobuf.Timestamp sessionExpiresAt = 1;
|
||||
}
|
||||
|
||||
// DismissSessionWarningRequest is sent by the UI when the user clicks
|
||||
// "Dismiss" on the T-WarningLead notification.
|
||||
message DismissSessionWarningRequest {}
|
||||
|
||||
// DismissSessionWarningResponse acknowledges the dismissal. Carries no
|
||||
// payload — the daemon's only obligation is to silence the upcoming
|
||||
// T-FinalWarningLead fallback for the current deadline.
|
||||
message DismissSessionWarningResponse {}
|
||||
|
||||
// StartCPUProfileRequest for starting CPU profiling
|
||||
message StartCPUProfileRequest {}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -53,7 +53,10 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
if engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
// Background ctx: the bundle wants a full, fresh probe regardless
|
||||
// of the DebugBundle RPC client's lifetime. The engine's own ctx
|
||||
// still aborts it on shutdown.
|
||||
engine.RunHealthProbes(context.Background(), true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,419 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// preSharedKeyRedactedSentinel is the value GetConfig returns in place
|
||||
// of an actual PSK, so a UI that round-trips the field back to the
|
||||
// daemon (via SetConfig / Login) can be distinguished from a deliberate
|
||||
// override. Any incoming PSK that equals this sentinel is treated as
|
||||
// a no-op echo, never as a conflict with the policy.
|
||||
const preSharedKeyRedactedSentinel = "**********"
|
||||
|
||||
// loadMDMPolicy is the indirection used by server handlers to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// conflictCheck is a value-aware comparison between a single field in
|
||||
// the incoming request and the corresponding MDM-enforced value. It
|
||||
// runs only when the field was actually set in the request (presence
|
||||
// already filtered upstream); ok=true reports the policy value, ok=false
|
||||
// means the policy is silent on the key — both are treated as conflicts
|
||||
// to be safe (an MDM key declared as managed must hold a value).
|
||||
type conflictCheck struct {
|
||||
key string
|
||||
check func(*mdm.Policy) (match bool)
|
||||
}
|
||||
|
||||
// onMDMPolicyChange is invoked by the MDM reload ticker every time the
|
||||
// OS-native managed-config store reports a diff vs the last observation.
|
||||
//
|
||||
// Restart sequence:
|
||||
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
|
||||
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
|
||||
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// applyMDMPolicy with the freshly loaded Policy).
|
||||
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
|
||||
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// RPC) can refresh its cached config view without polling.
|
||||
//
|
||||
// The callback runs in the ticker's own goroutine. Ticker has already
|
||||
// logged the per-key diff before invoking this hook.
|
||||
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
|
||||
log.Warn("MDM policy changed; restarting engine to apply new configuration")
|
||||
|
||||
// Hold s.mutex for the entire restart sequence (cancel + quiescence
|
||||
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
|
||||
// MDM is restarting blocks on the Lock until we are done — they
|
||||
// then observe the post-restart state coherently. This is safe
|
||||
// because the connectWithRetryRuns goroutine no longer acquires
|
||||
// s.mutex in its defer (intent vs. goroutine-alive concerns are
|
||||
// fully separated; see the connectionGoroutineRunning helper).
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.clientRunning {
|
||||
// The client is not running, so there's no engine to restart.
|
||||
return nil
|
||||
}
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
// Wait for previous connectWithRetryRuns to exit so we don't end up
|
||||
// with two goroutines fighting over the same status recorder + engine.
|
||||
// The teardown engages a fan-out of engine goroutines (peer workers,
|
||||
// signal handler, route manager, ...). close(clientGiveUpChan)
|
||||
// happens in the function-scope defer of connectWithRetryRuns, on
|
||||
// every exit path (ctx cancel, backoff exhausted, panic) — see the
|
||||
// defer in server.go.
|
||||
if s.clientGiveUpChan != nil {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return fmt.Errorf("failed to restart the engine due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.restartEngineForMDMLocked(); err != nil {
|
||||
log.Errorf("MDM restart failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent has already fired inside
|
||||
// restartEngineForMDMLocked with source="mdm". Emit an MDM-specific
|
||||
// user-visible toast so the operator knows their IT policy was
|
||||
// applied (UserMessage != "" triggers the GUI notifier).
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
"MDM policy applied",
|
||||
"NetBird configuration was updated by your IT policy.",
|
||||
map[string]string{"source": "mdm", "type": "policy_applied"},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent broadcasts a SystemEvent informing any active
|
||||
// SubscribeEvents subscriber (typically the GUI tray) that the daemon's
|
||||
// effective Config has been replaced and any cached client-side view
|
||||
// should be refreshed. Callers pass a stable `source` label so the GUI
|
||||
// can distinguish a startup spawn from a user-triggered Up or an
|
||||
// MDM-driven restart. Reusing the SYSTEM category keeps the proto enum
|
||||
// stable; metadata.type="config_changed" routes to the GUI's refresh
|
||||
// handler. UserMessage is left empty so the system tray does not toast
|
||||
// for every internal restart; the MDM path emits a separate
|
||||
// "policy_applied" event (with UserMessage) for that purpose.
|
||||
func (s *Server) publishConfigChangedEvent(source string) {
|
||||
if s.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
fmt.Sprintf("daemon config changed (source=%s)", source),
|
||||
"",
|
||||
map[string]string{
|
||||
"source": source,
|
||||
"type": "config_changed",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngineForMDMLocked re-resolves the active profile config
|
||||
// (re-running applyMDMPolicy via Config.apply) and re-spawns
|
||||
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
|
||||
// MDM change behaves identically to a fresh boot under the new policy.
|
||||
//
|
||||
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
|
||||
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
|
||||
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
|
||||
// state.
|
||||
func (s *Server) restartEngineForMDMLocked() error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile state: %w", err)
|
||||
}
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile config: %w", err)
|
||||
}
|
||||
|
||||
s.config = config
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("mdm")
|
||||
return nil
|
||||
}
|
||||
|
||||
// conflictBool builds a conflictCheck for a boolean MDM key. If p is nil
|
||||
// the field is treated as matching (no override requested); otherwise the
|
||||
// check returns true only when the policy contains the key and its
|
||||
// boolean value equals *p.
|
||||
func conflictBool(key string, p *bool) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true // absent → match by definition
|
||||
}
|
||||
want, ok := pol.GetBool(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictString builds a conflictCheck for a string MDM key. An empty
|
||||
// `got` is treated as "field not set" (no override requested); otherwise
|
||||
// the check returns true only when the policy contains the key and its
|
||||
// value equals got.
|
||||
func conflictString(key, got string) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if got == "" {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetString(key)
|
||||
return ok && want == got
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictInt64 builds a conflictCheck for an integer MDM key. If p is
|
||||
// nil the field is treated as matching; otherwise the check returns
|
||||
// true only when the policy contains the key and its int value equals *p.
|
||||
func conflictInt64(key string, p *int64) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetInt(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolveConflicts walks the per-field checks against the active MDM
|
||||
// policy and returns the names of keys whose requested value diverges
|
||||
// from the policy-enforced value. Keys not present in the policy are
|
||||
// skipped silently (the gate fires only for keys the admin has
|
||||
// actually pushed). Returns nil for an empty policy.
|
||||
func resolveConflicts(policy *mdm.Policy, checks []conflictCheck) []string {
|
||||
if policy.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
var conflicts []string
|
||||
for _, c := range checks {
|
||||
if !policy.HasKey(c.key) {
|
||||
continue
|
||||
}
|
||||
if !c.check(policy) {
|
||||
conflicts = append(conflicts, c.key)
|
||||
}
|
||||
}
|
||||
return conflicts
|
||||
}
|
||||
|
||||
// mdmManagedFieldConflicts returns the names of MDM-managed keys whose
|
||||
// requested value in the SetConfigRequest differs from the MDM-enforced
|
||||
// value. A field set to the same value the policy already enforces is
|
||||
// treated as a no-op echo (the GUI tray sends a full Config snapshot on
|
||||
// every toggle, so most fields in a typical request match the policy
|
||||
// exactly and must NOT be flagged as conflicts). The redacted PSK
|
||||
// sentinel ("**********") returned by GetConfig is recognised and
|
||||
// treated as no-op so the UI can safely round-trip it.
|
||||
func mdmManagedFieldConflicts(msg *proto.SetConfigRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PSK round-trip echo: collapse the sentinel to empty so the
|
||||
// shared check treats it as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != preSharedKeyRedactedSentinel {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// setConfigRequestHasConfigOverrides reports whether the SetConfigRequest
|
||||
// carries ANY field that would actually mutate the persisted config.
|
||||
// The CLI builds a SetConfigRequest unconditionally on every
|
||||
// `netbird up` (see setupSetConfigReq in cmd/up.go) — a plain
|
||||
// `netbird up` produces a request with every field at its zero value;
|
||||
// the gate must skip such no-op invocations or it would always fire
|
||||
// even when the user did not pass any --flag. Returns false on a nil
|
||||
// msg; true when any management/admin URL, PSK, DNS/NAT list+clean
|
||||
// flag, interface/port/MTU, or any optional bool/duration field is set.
|
||||
func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.Mtu != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil ||
|
||||
msg.DisableIpv6 != nil ||
|
||||
msg.EnableSSHRoot != nil ||
|
||||
msg.EnableSSHSFTP != nil ||
|
||||
msg.EnableSSHLocalPortForwarding != nil ||
|
||||
msg.EnableSSHRemotePortForwarding != nil ||
|
||||
msg.DisableSSHAuth != nil ||
|
||||
msg.SshJWTCacheTTL != nil
|
||||
}
|
||||
|
||||
// loginRequestHasConfigOverrides reports whether the LoginRequest
|
||||
// carries ANY field that would mutate persisted daemon configuration
|
||||
// (as opposed to pure-auth fields like setupKey, hostname, hint,
|
||||
// profileName, username). Used by the Login handler to decide whether
|
||||
// the `--disable-update-settings` / MDM gates must run: a re-auth that
|
||||
// changes nothing about the configuration is always allowed.
|
||||
func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.PreSharedKey != "" || //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil
|
||||
}
|
||||
|
||||
// loginRequestMDMConflicts mirrors mdmManagedFieldConflicts but for the
|
||||
// LoginRequest surface. Same value-aware semantics: a field set to the
|
||||
// MDM-enforced value is a no-op echo, not a conflict; only a divergent
|
||||
// value is flagged. PSK has two proto fields — PreSharedKey (deprecated)
|
||||
// and OptionalPreSharedKey (current); either route trips the gate if it
|
||||
// diverges from the MDM-enforced PSK. OptionalPreSharedKey wins when
|
||||
// both are set; the redaction sentinel ("**********") is accepted as
|
||||
// a no-op echo.
|
||||
func loginRequestMDMConflicts(msg *proto.LoginRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collapse the two PSK fields + the redaction sentinel down to a
|
||||
// single "got" string the shared check can compare against the
|
||||
// policy: OptionalPreSharedKey wins if set; PreSharedKey (deprecated)
|
||||
// is the fallback; sentinel echo is treated as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
} else if msg.PreSharedKey != "" { //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
pskGot = msg.PreSharedKey //nolint:staticcheck // SA1019
|
||||
}
|
||||
if pskGot == preSharedKeyRedactedSentinel {
|
||||
pskGot = ""
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// rejectMDMManagedFieldConflicts returns a FailedPrecondition gRPC error
|
||||
// with an MDMManagedFieldsViolation detail when any of the requested
|
||||
// fields tries to change an MDM-enforced value to something else, and
|
||||
// nil otherwise. The whole request is rejected on any conflict; non-
|
||||
// conflicting fields in the same request are not applied either (no
|
||||
// partial apply).
|
||||
func rejectMDMManagedFieldConflicts(conflicts []string) error {
|
||||
if len(conflicts) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Warnf("MDM rejected request: tried to modify %d managed key(s): %v",
|
||||
len(conflicts), conflicts)
|
||||
st := gstatus.New(
|
||||
codes.FailedPrecondition,
|
||||
fmt.Sprintf("fields managed by MDM cannot be modified: %v", conflicts),
|
||||
)
|
||||
detailed, err := st.WithDetails(&proto.MDMManagedFieldsViolation{Fields: conflicts})
|
||||
if err != nil {
|
||||
// Detail attachment is best-effort; fall back to the plain status
|
||||
// so the caller still gets a usable FailedPrecondition.
|
||||
return st.Err()
|
||||
}
|
||||
return detailed.Err()
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -172,6 +172,17 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
|
||||
return nil, fmt.Errorf("select routes: %w", err)
|
||||
}
|
||||
|
||||
// Exit nodes are mutually exclusive: if this selection activates an
|
||||
// exit node, deselect every other available exit node so two can't be
|
||||
// selected at once. Non-exit route selections are left untouched.
|
||||
if requestActivatesExitNode(routes, routesMap) {
|
||||
if others := otherExitNodeIDs(routesMap, routes); len(others) > 0 {
|
||||
if err := routeSelector.DeselectRoutes(others, netIdRoutes); err != nil {
|
||||
return nil, fmt.Errorf("deselect sibling exit nodes: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||
|
||||
@@ -195,7 +206,7 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -249,3 +260,38 @@ func toNetIDs(routes []string) []route.NetID {
|
||||
}
|
||||
return netIDs
|
||||
}
|
||||
|
||||
func isExitNodeRoutes(routes []*route.Route) bool {
|
||||
return len(routes) > 0 && (route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network))
|
||||
}
|
||||
|
||||
// requestActivatesExitNode reports whether any requested NetID maps to an exit
|
||||
// node (default route) in the current route table.
|
||||
func requestActivatesExitNode(requested []route.NetID, routesMap map[route.NetID][]*route.Route) bool {
|
||||
for _, id := range requested {
|
||||
if isExitNodeRoutes(routesMap[id]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// otherExitNodeIDs returns every available exit-node NetID that is not in the
|
||||
// requested set — the siblings to deselect so a single exit node stays active.
|
||||
func otherExitNodeIDs(routesMap map[route.NetID][]*route.Route, requested []route.NetID) []route.NetID {
|
||||
keep := make(map[route.NetID]struct{}, len(requested))
|
||||
for _, id := range requested {
|
||||
keep[id] = struct{}{}
|
||||
}
|
||||
var others []route.NetID
|
||||
for id, routes := range routesMap {
|
||||
if !isExitNodeRoutes(routes) {
|
||||
continue
|
||||
}
|
||||
if _, ok := keep[id]; ok {
|
||||
continue
|
||||
}
|
||||
others = append(others, id)
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
26
client/server/network_exitnode_test.go
Normal file
26
client/server/network_exitnode_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestExitNodeSelectionHelpers(t *testing.T) {
|
||||
routesMap := map[route.NetID][]*route.Route{
|
||||
"exitA": {{Network: netip.MustParsePrefix("0.0.0.0/0")}},
|
||||
"exitB": {{Network: netip.MustParsePrefix("::/0")}},
|
||||
"lan": {{Network: netip.MustParsePrefix("192.168.0.0/16")}},
|
||||
}
|
||||
|
||||
assert.True(t, requestActivatesExitNode([]route.NetID{"exitA"}, routesMap), "v4 default route is an exit node")
|
||||
assert.True(t, requestActivatesExitNode([]route.NetID{"exitB"}, routesMap), "v6 default route is an exit node")
|
||||
assert.False(t, requestActivatesExitNode([]route.NetID{"lan"}, routesMap), "lan route is not an exit node")
|
||||
assert.False(t, requestActivatesExitNode([]route.NetID{"missing"}, routesMap), "unknown id is not an exit node")
|
||||
|
||||
others := otherExitNodeIDs(routesMap, []route.NetID{"exitB"})
|
||||
assert.ElementsMatch(t, []route.NetID{"exitA"}, others, "only the other exit node is a sibling; the lan route is ignored")
|
||||
}
|
||||
88
client/server/probe_throttle.go
Normal file
88
client/server/probe_throttle.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// healthProbeRunner runs the full, expensive probe (network round-trips to
|
||||
// management, signal and the relays) and reports whether every component was
|
||||
// healthy. ctx cancels the probe when the caller gives up. Satisfied by
|
||||
// *internal.Engine.
|
||||
type healthProbeRunner interface {
|
||||
RunHealthProbes(ctx context.Context, waitForResult bool) bool
|
||||
}
|
||||
|
||||
// statsRefresher does the cheap WireGuard-stats refresh callers fall back to
|
||||
// when a fresh probe isn't warranted. Satisfied by *peer.Status.
|
||||
type statsRefresher interface {
|
||||
RefreshWireGuardStats() error
|
||||
}
|
||||
|
||||
// probeThrottle rate-limits and single-flights the daemon's health probes.
|
||||
//
|
||||
// Health probes are expensive (network round-trips to management, signal and
|
||||
// the relays), while Status(GetFullPeerStatus=true) RPCs can arrive frequently
|
||||
// and concurrently — the desktop UI alone issues one per connect/disconnect.
|
||||
// probeThrottle keeps that load bounded with two rules:
|
||||
//
|
||||
// - Single-flight: only one probe runs at a time. Callers that pile up while
|
||||
// a probe is in flight share its result instead of each launching another,
|
||||
// even when that probe failed. A failed probe therefore does not make every
|
||||
// waiter re-probe in turn; the next, non-overlapping caller can try again.
|
||||
// - Throttle: after a fully successful probe the result is cached for
|
||||
// interval. While any component is unhealthy the cache is not advanced, so
|
||||
// later callers keep probing frequently and notice recovery quickly — the
|
||||
// intentional "probe often while unhealthy" behaviour from the original
|
||||
// design.
|
||||
type probeThrottle struct {
|
||||
interval time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
lastOK time.Time // last fully-successful probe; drives the throttle window
|
||||
completedAt time.Time // when the most recent probe finished; drives single-flight sharing
|
||||
}
|
||||
|
||||
func newProbeThrottle(interval time.Duration) *probeThrottle {
|
||||
return &probeThrottle{interval: interval}
|
||||
}
|
||||
|
||||
// Run decides whether to run a fresh health probe or serve the most recent
|
||||
// result. It serialises concurrent callers: at most one runner.RunHealthProbes
|
||||
// executes at a time and the rest call refresher.RefreshWireGuardStats and read
|
||||
// the snapshot it produced.
|
||||
//
|
||||
// Both calls run while the throttle's lock is held, so a slow probe blocks
|
||||
// other callers until it completes — that blocking is the single-flight
|
||||
// guarantee. ctx is forwarded to RunHealthProbes so a caller that gives up
|
||||
// cancels the in-flight probe (and any caller still queued on the lock falls
|
||||
// through quickly once it acquires it, since the probe ctx is already done).
|
||||
func (t *probeThrottle) Run(ctx context.Context, runner healthProbeRunner, refresher statsRefresher, waitForResult bool) {
|
||||
entered := time.Now()
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// A probe that finished after we entered ran while we were waiting on the
|
||||
// lock — i.e. a peer in the same burst already probed for us, so share its
|
||||
// result rather than launch another. This holds even when that probe
|
||||
// failed, so a failed probe doesn't make every waiter re-probe in turn.
|
||||
sharedRecentProbe := t.completedAt.After(entered)
|
||||
throttled := time.Since(t.lastOK) <= t.interval
|
||||
|
||||
if sharedRecentProbe || throttled {
|
||||
if err := refresher.RefreshWireGuardStats(); err != nil {
|
||||
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
healthy := runner.RunHealthProbes(ctx, waitForResult)
|
||||
t.completedAt = time.Now()
|
||||
if healthy {
|
||||
t.lastOK = t.completedAt
|
||||
}
|
||||
}
|
||||
109
client/server/probe_throttle_test.go
Normal file
109
client/server/probe_throttle_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeProber implements both healthProbeRunner and statsRefresher with
|
||||
// caller-supplied behaviour.
|
||||
type fakeProber struct {
|
||||
onProbe func() bool
|
||||
onRefresh func()
|
||||
}
|
||||
|
||||
func (f fakeProber) RunHealthProbes(context.Context, bool) bool {
|
||||
return f.onProbe()
|
||||
}
|
||||
|
||||
func (f fakeProber) RefreshWireGuardStats() error {
|
||||
if f.onRefresh != nil {
|
||||
f.onRefresh()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestProbeThrottle_CachesAfterSuccess(t *testing.T) {
|
||||
pt := newProbeThrottle(time.Minute)
|
||||
|
||||
var probes, refreshes int
|
||||
prober := fakeProber{
|
||||
onProbe: func() bool { probes++; return true },
|
||||
onRefresh: func() { refreshes++ },
|
||||
}
|
||||
|
||||
pt.Run(context.Background(), prober, prober, false)
|
||||
pt.Run(context.Background(), prober, prober, false)
|
||||
|
||||
if probes != 1 {
|
||||
t.Fatalf("expected 1 probe within the throttle window, got %d", probes)
|
||||
}
|
||||
if refreshes != 1 {
|
||||
t.Fatalf("expected the throttled caller to refresh stats once, got %d", refreshes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeThrottle_StaysOpenWhileUnhealthy(t *testing.T) {
|
||||
pt := newProbeThrottle(time.Minute)
|
||||
|
||||
var probes int
|
||||
prober := fakeProber{onProbe: func() bool { probes++; return false }} // never healthy
|
||||
|
||||
// Sequential, non-overlapping callers must each re-probe while unhealthy:
|
||||
// a failed probe does not advance the throttle window.
|
||||
pt.Run(context.Background(), prober, prober, false)
|
||||
pt.Run(context.Background(), prober, prober, false)
|
||||
pt.Run(context.Background(), prober, prober, false)
|
||||
|
||||
if probes != 3 {
|
||||
t.Fatalf("expected every non-overlapping caller to probe while unhealthy, got %d", probes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeThrottle_SingleFlightSharesResult(t *testing.T) {
|
||||
pt := newProbeThrottle(time.Minute)
|
||||
|
||||
var probes int32
|
||||
release := make(chan struct{})
|
||||
started := make(chan struct{})
|
||||
|
||||
// First caller blocks inside the probe until released, holding the lock so
|
||||
// the others pile up behind it.
|
||||
prober := fakeProber{onProbe: func() bool {
|
||||
if atomic.AddInt32(&probes, 1) == 1 {
|
||||
close(started)
|
||||
<-release
|
||||
}
|
||||
return false // unhealthy — the share must happen regardless of result
|
||||
}}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pt.Run(context.Background(), prober, prober, false)
|
||||
}()
|
||||
|
||||
<-started // ensure the first probe is in flight before the burst arrives
|
||||
|
||||
const waiters = 9
|
||||
wg.Add(waiters)
|
||||
for i := 0; i < waiters; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pt.Run(context.Background(), prober, prober, false)
|
||||
}()
|
||||
}
|
||||
|
||||
// Give the waiters time to block on the lock, then let the first finish.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
close(release)
|
||||
wg.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(&probes); got != 1 {
|
||||
t.Fatalf("expected a concurrent burst to run exactly 1 probe, got %d", got)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user