Compare commits
174 Commits
fix/rosenp
...
ui-refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f957ff41a | ||
|
|
598fcbd817 | ||
|
|
17a365926d | ||
|
|
577ce6deb5 | ||
|
|
580cfa0dc5 | ||
|
|
8d4f35352f | ||
|
|
85029898a5 | ||
|
|
c3aeb5be15 | ||
|
|
df61f22d96 | ||
|
|
32df29bbd4 | ||
|
|
0a458ead8b | ||
|
|
aab8274b1a | ||
|
|
d3b660afba | ||
|
|
341848b1ae | ||
|
|
414e7815e4 | ||
|
|
ef6b4f7538 | ||
|
|
a7b26e3c0d | ||
|
|
42534b24c5 | ||
|
|
2aea1f7bb5 | ||
|
|
620233a7ac | ||
|
|
1c15e9976b | ||
|
|
f04e2bada8 | ||
|
|
1d88faf66f | ||
|
|
84093af1f0 | ||
|
|
34a4744565 | ||
|
|
b79b62bee4 | ||
|
|
bec4eb326a | ||
|
|
8748f3810d | ||
|
|
1c5254cb31 | ||
|
|
3f8cd29006 | ||
|
|
af24fd7796 | ||
|
|
13d32d274f | ||
|
|
ca48de549e | ||
|
|
5b71a4f2ad | ||
|
|
741ce8581d | ||
|
|
705f87fc20 | ||
|
|
6b44d65cac | ||
|
|
f84b1df857 | ||
|
|
c24349e4f1 | ||
|
|
7f7bee630f | ||
|
|
4e0eb9f2d4 | ||
|
|
38a367e0cd | ||
|
|
78fb15e327 | ||
|
|
35e58a2796 | ||
|
|
a6278936af | ||
|
|
32f62f3ed8 | ||
|
|
7fae703a27 | ||
|
|
f468f15a30 | ||
|
|
3f91f49277 | ||
|
|
347c5bf317 | ||
|
|
22e2519d71 | ||
|
|
e916f12cca | ||
|
|
9ed2e2a5b4 | ||
|
|
2ccae7ec47 | ||
|
|
5bdccfe8f4 | ||
|
|
cccb0e9230 | ||
|
|
9d8eb76746 | ||
|
|
1ebb507cbb | ||
|
|
5411fa4350 | ||
|
|
17cae1a75c | ||
|
|
c0b0eeb6ab | ||
|
|
d32721d7fc | ||
|
|
288f8dec08 | ||
|
|
db8c9a0e30 | ||
|
|
505fcc7f7a | ||
|
|
07e5450117 | ||
|
|
3f914090cb | ||
|
|
ea9fab4396 | ||
|
|
0fe8764707 | ||
|
|
c0e7c61c4b | ||
|
|
e4eedbe18f | ||
|
|
fc1db63fc3 | ||
|
|
d841a6aa07 | ||
|
|
77b479286e | ||
|
|
ab2a8794e7 | ||
|
|
258e7ec038 | ||
|
|
1932b76f5b | ||
|
|
d33b841a33 | ||
|
|
df1935da6d | ||
|
|
eb6be5a2f3 | ||
|
|
209f14fc2f | ||
|
|
2bd56ecf67 | ||
|
|
67988c2407 | ||
|
|
53b2fb8dc1 | ||
|
|
803144e569 | ||
|
|
c0cd88a3d0 | ||
|
|
6c9b821bf0 | ||
|
|
83030dbbd6 | ||
|
|
1c8a6e3798 | ||
|
|
74ea03da9b | ||
|
|
77fdf23a50 | ||
|
|
1f4ed5c8ef | ||
|
|
e1bf362675 | ||
|
|
af40ee52f8 | ||
|
|
4988f2aa68 | ||
|
|
e3efaa5e59 | ||
|
|
100d25a062 | ||
|
|
04b4330393 | ||
|
|
c8e18585c6 | ||
|
|
1931a2c8a8 | ||
|
|
108d43e702 | ||
|
|
842ef0d657 | ||
|
|
439f44c6b4 | ||
|
|
b5a970155b | ||
|
|
686e0d97f2 | ||
|
|
0c287b6f4d | ||
|
|
f7f5946910 | ||
|
|
7a9f5a734f | ||
|
|
1aae067aaa | ||
|
|
28a7eba756 | ||
|
|
8841b950a2 | ||
|
|
0c2702c0d7 | ||
|
|
b43a09a1c7 | ||
|
|
595dfbb6f1 | ||
|
|
7f560df9be | ||
|
|
09052949a2 | ||
|
|
9aef31ff53 | ||
|
|
08f52f4517 | ||
|
|
18e3b5dd32 | ||
|
|
f3f9704c6f | ||
|
|
4c3d4effbd | ||
|
|
3953fee5a4 | ||
|
|
adeaa49cda | ||
|
|
2c5d52a1bf | ||
|
|
70a755fbae | ||
|
|
559da5d5b9 | ||
|
|
614ee11ac7 | ||
|
|
85080afa59 | ||
|
|
a5cc8da054 | ||
|
|
a4fd5a78b4 | ||
|
|
062a183e4e | ||
|
|
a2be41caf8 | ||
|
|
5b70989e3e | ||
|
|
d324a5ff48 | ||
|
|
debb558aa3 | ||
|
|
cce80f8276 | ||
|
|
05ee4e52b8 | ||
|
|
bb2bf673a0 | ||
|
|
91c745e5e8 | ||
|
|
68c38247f1 | ||
|
|
8b8f38de1b | ||
|
|
2b272e74c8 | ||
|
|
e6cbf30415 | ||
|
|
490b60ad0e | ||
|
|
553be144b4 | ||
|
|
c3f9514182 | ||
|
|
a8812d5fb1 | ||
|
|
6f93cf6ac3 | ||
|
|
18909390c2 | ||
|
|
b3eb5f2453 | ||
|
|
dc02542a9e | ||
|
|
0c136fffb9 | ||
|
|
fffb9dd219 | ||
|
|
93275f9052 | ||
|
|
dd9c15072f | ||
|
|
4c743bc03d | ||
|
|
2e61b42e92 | ||
|
|
3f8de2a149 | ||
|
|
bc609c3ae7 | ||
|
|
e3994d0c99 | ||
|
|
ba6e10cef3 | ||
|
|
ce53981b55 | ||
|
|
a69037630b | ||
|
|
df58935cc0 | ||
|
|
a1743dbf9b | ||
|
|
f9771de3f5 | ||
|
|
bfe19fa542 | ||
|
|
d07f25fc49 | ||
|
|
670b0f66ac | ||
|
|
15d73a2edd | ||
|
|
88a2bf582d | ||
|
|
0148d926d5 | ||
|
|
8f16a19b8f | ||
|
|
504dceedf3 |
10
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,13 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the grep then drops the broken package by path. Without -e,
|
||||
# go list aborts with empty stdout and `go test` falls back to the repo
|
||||
# root, which has no Go files.
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
|
||||
|
||||
|
||||
16
.github/workflows/golang-test-linux.yml
vendored
@@ -51,7 +51,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
@@ -141,7 +141,7 @@ jobs:
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
@@ -154,7 +154,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the grep then drops the broken package by path. Without -e,
|
||||
# go list aborts with empty stdout and `go test` falls back to the repo
|
||||
# root, which has no Go files.
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -214,7 +222,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -e -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
|
||||
9
.github/workflows/golang-test-windows.yml
vendored
@@ -64,8 +64,15 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- name: Generate test script
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the Where-Object pipeline then drops the broken package by
|
||||
# path. Without -e, go list aborts with empty stdout.
|
||||
run: |
|
||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||
$packages = go list -e ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } | Where-Object { $_ -notmatch '/client/ui' }
|
||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||
|
||||
13
.github/workflows/golangci-lint.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
skip: go.mod,go.sum,**/proxy/web/**,**/pnpm-lock.yaml,**/package-lock.json
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -50,7 +50,16 @@ jobs:
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Stub Wails frontend bundle
|
||||
# client/ui/main.go has //go:embed all:frontend/dist. The
|
||||
# directory is produced by `pnpm run build` and is gitignored, so
|
||||
# lint-only runs (no frontend toolchain) need a placeholder file
|
||||
# for the embed pattern to match.
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p client/ui/frontend/dist
|
||||
touch client/ui/frontend/dist/.embed-placeholder
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
with:
|
||||
|
||||
80
.github/workflows/release.yml
vendored
@@ -186,9 +186,9 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
run: goversioninfo -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
@@ -349,8 +349,18 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Set up pnpm
|
||||
uses: pnpm/action-setup@v3
|
||||
with:
|
||||
version: 9
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev gcc-mingw-w64-x86-64
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
@@ -369,10 +379,16 @@ jobs:
|
||||
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the binding generator always matches
|
||||
# the wails runtime the binary links against.
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
run: goversioninfo -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
@@ -439,6 +455,20 @@ jobs:
|
||||
run: go mod tidy
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
- name: Set up pnpm
|
||||
uses: pnpm/action-setup@v3
|
||||
with:
|
||||
version: 9
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the binding generator always matches
|
||||
# the wails runtime the binary links against.
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
@@ -528,24 +558,6 @@ jobs:
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
with:
|
||||
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
||||
file-name: mesa3d.7z
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
|
||||
|
||||
- name: Move opengl32.dll into dist (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: carlosperate/download-file-action@v2
|
||||
with:
|
||||
@@ -568,6 +580,28 @@ jobs:
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Set up Go for wails3 CLI
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the bootstrapper payload always
|
||||
# matches the wails runtime the binary links against.
|
||||
shell: bash
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
|
||||
- name: Stage WebView2 bootstrapper for installers
|
||||
# Both client/installer.nsis and client/netbird.wxs reference
|
||||
# client/MicrosoftEdgeWebview2Setup.exe. wails3 writes it there.
|
||||
# The signing pipeline (netbirdio/sign-pipelines) does the same
|
||||
# step for release builds; this mirrors it for PR sanity testing.
|
||||
shell: bash
|
||||
run: wails3 generate webview2bootstrapper -dir client
|
||||
|
||||
- name: Build NSIS installer
|
||||
uses: joncloud/makensis-action@v3.3
|
||||
with:
|
||||
|
||||
2
.github/workflows/wasm-build-validation.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
with:
|
||||
|
||||
@@ -114,6 +114,16 @@ linters:
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1012"
|
||||
# client/ui/main.go uses //go:embed all:frontend/dist; the
|
||||
# directory is populated by `pnpm build` in the release pipeline
|
||||
# and missing at lint time, so the embed parses to "no matching
|
||||
# files found" — surfaced by golangci-lint's typecheck pre-pass.
|
||||
# Suppress just that one diagnostic; the rest of the package
|
||||
# (services/, tray.go, grpc.go, ...) still gets linted normally.
|
||||
- linters:
|
||||
- typecheck
|
||||
path: client/ui/main\.go
|
||||
text: "pattern all:frontend/dist"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
|
||||
before:
|
||||
hooks:
|
||||
# Bindings are gitignored; regenerate before the frontend build so
|
||||
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
|
||||
# build without them).
|
||||
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
|
||||
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
|
||||
|
||||
builds:
|
||||
- id: netbird-ui
|
||||
dir: client/ui
|
||||
@@ -70,12 +79,15 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/build/netbird.desktop
|
||||
- src: client/ui/build/linux/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/assets/netbird.png
|
||||
- src: client/ui/build/appicon.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
- libgtk-3-0
|
||||
- libwebkit2gtk-4.1-0
|
||||
- libayatana-appindicator3-1
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
@@ -89,12 +101,15 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/build/netbird.desktop
|
||||
- src: client/ui/build/linux/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/assets/netbird.png
|
||||
- src: client/ui/build/appicon.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
- gtk3
|
||||
- webkit2gtk4.1
|
||||
- libayatana-appindicator-gtk3
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
|
||||
before:
|
||||
hooks:
|
||||
# Bindings are gitignored; regenerate before the frontend build so
|
||||
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
|
||||
# build without them).
|
||||
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
|
||||
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
|
||||
|
||||
builds:
|
||||
- id: netbird-ui-darwin
|
||||
dir: client/ui
|
||||
@@ -20,8 +29,6 @@ builds:
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird-ui-darwin
|
||||
|
||||
@@ -22,11 +22,19 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// extendSessionFlag drives the `netbird login --extend` flow: refresh the
|
||||
// SSO session expiry on the management server without tearing down the
|
||||
// tunnel. Mutually exclusive with setup-key login (a setup-key cannot
|
||||
// refresh an SSO-tracked peer — see auth.errSetupKeyOnSSOExpiredPeer).
|
||||
var extendSessionFlag bool
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
loginCmd.PersistentFlags().BoolVar(&extendSessionFlag, "extend", false,
|
||||
"refresh the SSO session expiry without tearing down the tunnel (requires an active connection)")
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
@@ -61,6 +69,16 @@ var loginCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
if extendSessionFlag {
|
||||
if providedSetupKey != "" {
|
||||
return fmt.Errorf("--extend cannot be combined with a setup key; setup keys can only enrol new peers")
|
||||
}
|
||||
if err := doExtendSession(ctx, cmd); err != nil {
|
||||
return fmt.Errorf("extend session failed: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if util.FindFirstLogPath(logFiles) == "" {
|
||||
if err := doForegroundLogin(ctx, cmd, providedSetupKey, activeProf); err != nil {
|
||||
@@ -150,6 +168,65 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
return nil
|
||||
}
|
||||
|
||||
// doExtendSession drives the daemon's RequestExtendAuthSession /
|
||||
// WaitExtendAuthSession pair. The user is sent through a regular SSO flow
|
||||
// (browser + verification URL) and the resulting JWT is forwarded to the
|
||||
// management server's ExtendAuthSession RPC. The tunnel stays up
|
||||
// throughout — no Down/Up, no network-map resync.
|
||||
func doExtendSession(ctx context.Context, cmd *cobra.Command) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
req := &proto.RequestExtendAuthSessionRequest{}
|
||||
// Pre-fill the IdP login hint from the active profile so the user
|
||||
// doesn't have to retype their email. Best-effort: we still proceed
|
||||
// without a hint if the lookup fails.
|
||||
pm := profilemanager.NewProfileManager()
|
||||
if active, perr := pm.GetActiveProfile(); perr == nil {
|
||||
if profState, sperr := pm.GetProfileState(active.Name); sperr == nil && profState.Email != "" {
|
||||
req.Hint = &profState.Email
|
||||
}
|
||||
}
|
||||
|
||||
startResp, err := client.RequestExtendAuthSession(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("start extend session: %v", err)
|
||||
}
|
||||
|
||||
uri := startResp.GetVerificationURIComplete()
|
||||
if uri == "" {
|
||||
uri = startResp.GetVerificationURI()
|
||||
}
|
||||
openURL(cmd, uri, startResp.GetUserCode(), noBrowser, showQR)
|
||||
|
||||
waitResp, err := client.WaitExtendAuthSession(ctx, &proto.WaitExtendAuthSessionRequest{
|
||||
DeviceCode: startResp.GetDeviceCode(),
|
||||
UserCode: startResp.GetUserCode(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("wait for extend session: %v", err)
|
||||
}
|
||||
|
||||
if ts := waitResp.GetSessionExpiresAt(); ts.IsValid() && !ts.AsTime().IsZero() {
|
||||
deadline := ts.AsTime().Local()
|
||||
cmd.Printf("Session extended. New expiry: %s\n", deadline.Format("2006-01-02 15:04:05 MST"))
|
||||
} else {
|
||||
// Management reported the peer is not eligible (e.g. login
|
||||
// expiration disabled on the account). Surface that fact
|
||||
// instead of pretending the call succeeded.
|
||||
cmd.Println("Session extension call completed, but the management server did not return a new deadline (peer may not be SSO-tracked or login expiration is disabled).")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) (*profilemanager.Profile, error) {
|
||||
// switch profile if provided
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
@@ -43,16 +44,16 @@ func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
|
||||
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -117,6 +118,11 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var sessionExpiresAt time.Time
|
||||
if ts := resp.GetSessionExpiresAt(); ts.IsValid() {
|
||||
sessionExpiresAt = ts.AsTime().UTC()
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
@@ -127,6 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
IPsFilter: ipsFilterMap,
|
||||
ConnectionTypeFilter: connectionTypeFilter,
|
||||
ProfileName: profName,
|
||||
SessionExpiresAt: sessionExpiresAt,
|
||||
})
|
||||
var statusOutputString string
|
||||
switch {
|
||||
|
||||
@@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
@@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
listenAddr := net.JoinHostPort(addr.String(), port)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
!define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
|
||||
!define INSTALLER_NAME "netbird-installer.exe"
|
||||
!define MAIN_APP_EXE "Netbird"
|
||||
!define ICON "ui\\assets\\netbird.ico"
|
||||
!define ICON "ui\\build\\windows\\icon.ico"
|
||||
!define BANNER "ui\\build\\banner.bmp"
|
||||
!define LICENSE_DATA "..\\LICENSE"
|
||||
|
||||
@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
|
||||
; or HKCU by legacy installers.
|
||||
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
@@ -280,6 +288,43 @@ CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
SectionEnd
|
||||
|
||||
# Install the Microsoft Edge WebView2 runtime if it isn't already present.
|
||||
# Macro adapted from Wails3's NSIS template (wails_tools.nsh): a registry
|
||||
# probe followed by a silent install of the embedded evergreen bootstrapper.
|
||||
# The MicrosoftEdgeWebview2Setup.exe payload is staged next to this script
|
||||
# by the sign-pipelines build step (`wails3 generate webview2bootstrapper`).
|
||||
!macro nb.webview2runtime
|
||||
SetRegView 64
|
||||
# Per-machine install marker — populated when the runtime ships with
|
||||
# Edge or has been installed by an admin previously.
|
||||
ReadRegStr $0 HKLM "SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
|
||||
${If} $0 != ""
|
||||
Goto webview2_ok
|
||||
${EndIf}
|
||||
# Per-user fallback for HKCU installs.
|
||||
ReadRegStr $0 HKCU "Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
|
||||
${If} $0 != ""
|
||||
Goto webview2_ok
|
||||
${EndIf}
|
||||
|
||||
SetDetailsPrint both
|
||||
DetailPrint "Installing: WebView2 Runtime"
|
||||
SetDetailsPrint listonly
|
||||
|
||||
InitPluginsDir
|
||||
CreateDirectory "$pluginsdir\webview2bootstrapper"
|
||||
SetOutPath "$pluginsdir\webview2bootstrapper"
|
||||
File "MicrosoftEdgeWebview2Setup.exe"
|
||||
ExecWait '"$pluginsdir\webview2bootstrapper\MicrosoftEdgeWebview2Setup.exe" /silent /install'
|
||||
|
||||
SetDetailsPrint both
|
||||
webview2_ok:
|
||||
!macroend
|
||||
|
||||
Section -WebView2
|
||||
!insertmacro nb.webview2runtime
|
||||
SectionEnd
|
||||
|
||||
Section -Post
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install'
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start'
|
||||
@@ -299,11 +344,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
DetailPrint "Terminating Netbird UI process..."
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart registry entry
|
||||
; Remove autostart entries from every view a previous installer may have used.
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
DetailPrint "Checking if user requested data deletion..."
|
||||
@@ -326,9 +376,9 @@ DetailPrint "Deleting application files..."
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
!if ${ARCH} == "amd64"
|
||||
# Legacy: pre-Wails installs shipped opengl32.dll (Mesa3D for Fyne); remove
|
||||
# any leftover copy on uninstall so old upgrades don't leave it behind.
|
||||
Delete "$INSTDIR\opengl32.dll"
|
||||
!endif
|
||||
DetailPrint "Removing application directory..."
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +22,25 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// peerLoginExpiredMsg is the exact phrase the management server returns
|
||||
// when a previously SSO-enrolled peer's login has expired. Sourced from
|
||||
// shared/management/status/error.go (NewPeerLoginExpiredError). Matched
|
||||
// by substring so a future server-side rewording that keeps the phrase
|
||||
// still triggers the friendly fallback in Login().
|
||||
const peerLoginExpiredMsg = "peer login has expired"
|
||||
|
||||
// errSetupKeyOnSSOExpiredPeer replaces the raw management error when the
|
||||
// user runs `netbird login -k <setup-key>` against a peer that was
|
||||
// originally enrolled via SSO. Wrapped in a PermissionDenied gRPC status
|
||||
// so callers' existing isPermissionDenied / isAuthError checks still
|
||||
// classify it correctly (early-exit from retry backoff, StatusNeedsLogin
|
||||
// in the server state machine).
|
||||
var errSetupKeyOnSSOExpiredPeer = status.Error(
|
||||
codes.PermissionDenied,
|
||||
"this peer was originally enrolled via SSO and its session has expired. "+
|
||||
"Setup keys can only enrol new peers — run `netbird up` (interactive SSO) to re-login.",
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
@@ -184,6 +204,15 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
// The peer pub-key is already on file with the management
|
||||
// server (originally enrolled via SSO) and the session has
|
||||
// expired. The setup-key path can only enrol new peers, so
|
||||
// retrying with -k will keep failing. Replace the raw mgm
|
||||
// message with an actionable hint that tells the user to
|
||||
// re-authenticate via SSO instead.
|
||||
if setupKey != "" && jwtToken == "" && isPeerLoginExpired(err) {
|
||||
err = errSetupKeyOnSSOExpiredPeer
|
||||
}
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
@@ -474,3 +503,16 @@ func isLoginNeeded(err error) bool {
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
|
||||
// isPeerLoginExpired reports whether err is the management server's
|
||||
// "peer login has expired" PermissionDenied response. Used by Login to
|
||||
// detect the case where the caller passed a setup-key but the peer is
|
||||
// actually an SSO-enrolled record whose session needs refreshing — the
|
||||
// setup-key path cannot help there.
|
||||
func isPeerLoginExpired(err error) bool {
|
||||
if !isPermissionDenied(err) {
|
||||
return false
|
||||
}
|
||||
s, _ := status.FromError(err)
|
||||
return strings.Contains(s.Message(), peerLoginExpiredMsg)
|
||||
}
|
||||
|
||||
80
client/internal/auth/auth_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestIsPeerLoginExpired(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "plain error (not a gRPC status)",
|
||||
err: errors.New("network read: connection reset"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "PermissionDenied with different message",
|
||||
err: status.Error(codes.PermissionDenied, "user is blocked"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated with the expected phrase",
|
||||
// Wrong status code — must still return false.
|
||||
err: status.Error(codes.Unauthenticated, "peer login has expired, please log in once more"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exact server message",
|
||||
err: status.Error(codes.PermissionDenied, "peer login has expired, please log in once more"),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "phrase as substring",
|
||||
// Future-proofing: if mgm reworords but keeps the phrase,
|
||||
// the friendly fallback must still kick in.
|
||||
err: status.Error(codes.PermissionDenied, "session refused: peer login has expired (account=foo)"),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := isPeerLoginExpired(tc.err); got != tc.want {
|
||||
t.Fatalf("isPeerLoginExpired(%v) = %v, want %v", tc.err, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrSetupKeyOnSSOExpiredPeer(t *testing.T) {
|
||||
// Sentinel must surface as PermissionDenied so the upstream
|
||||
// isPermissionDenied / isAuthError checks classify it correctly
|
||||
// (short-circuit retry backoff, set StatusNeedsLogin).
|
||||
if !isPermissionDenied(errSetupKeyOnSSOExpiredPeer) {
|
||||
t.Fatalf("errSetupKeyOnSSOExpiredPeer must be a PermissionDenied gRPC error")
|
||||
}
|
||||
|
||||
// Message must actually mention SSO and `netbird up` so it is
|
||||
// actionable for the end user. Loose substring checks keep the
|
||||
// test resilient to copy edits.
|
||||
s, _ := status.FromError(errSetupKeyOnSSOExpiredPeer)
|
||||
msg := strings.ToLower(s.Message())
|
||||
for _, want := range []string{"sso", "netbird up"} {
|
||||
if !strings.Contains(msg, want) {
|
||||
t.Errorf("sentinel message should contain %q, got %q", want, s.Message())
|
||||
}
|
||||
}
|
||||
}
|
||||
89
client/internal/auth/pending_flow.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PendingFlow stores an in-progress OAuth flow between the RPC that
|
||||
// initiates it (returns the verification URI to the UI) and the RPC
|
||||
// that waits for the user to complete it. The flow handle, the
|
||||
// device-code info, and the absolute expiry are kept together so the
|
||||
// waiting RPC can validate the device code and reuse the same flow.
|
||||
//
|
||||
// PendingFlow is safe for concurrent use; callers must not access the
|
||||
// stored fields directly.
|
||||
type PendingFlow struct {
|
||||
mu sync.Mutex
|
||||
flow OAuthFlow
|
||||
info AuthFlowInfo
|
||||
expiresAt time.Time
|
||||
waitCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewPendingFlow returns an empty PendingFlow ready to be populated by Set.
|
||||
func NewPendingFlow() *PendingFlow {
|
||||
return &PendingFlow{}
|
||||
}
|
||||
|
||||
// Set stores the flow and its authorization info, computing the absolute
|
||||
// expiry from info.ExpiresIn (seconds, as returned by the IdP).
|
||||
func (p *PendingFlow) Set(flow OAuthFlow, info AuthFlowInfo) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.flow = flow
|
||||
p.info = info
|
||||
p.expiresAt = time.Now().Add(time.Duration(info.ExpiresIn) * time.Second)
|
||||
}
|
||||
|
||||
// Get returns the stored flow, info, and whether a flow is currently
|
||||
// pending. Returns (nil, zero, false) after Clear or before Set.
|
||||
func (p *PendingFlow) Get() (OAuthFlow, AuthFlowInfo, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.flow == nil {
|
||||
return nil, AuthFlowInfo{}, false
|
||||
}
|
||||
return p.flow, p.info, true
|
||||
}
|
||||
|
||||
// ExpiresAt returns the absolute expiry of the pending flow. Returns
|
||||
// the zero time when no flow is pending.
|
||||
func (p *PendingFlow) ExpiresAt() time.Time {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.expiresAt
|
||||
}
|
||||
|
||||
// SetWaitCancel records the cancel function for the goroutine currently
|
||||
// blocked in WaitToken so a new RequestAuth can preempt it.
|
||||
func (p *PendingFlow) SetWaitCancel(cancel context.CancelFunc) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.waitCancel = cancel
|
||||
}
|
||||
|
||||
// CancelWait invokes and clears the stored wait-cancel, if any. Safe to
|
||||
// call when no wait is in progress.
|
||||
func (p *PendingFlow) CancelWait() {
|
||||
p.mu.Lock()
|
||||
cancel := p.waitCancel
|
||||
p.waitCancel = nil
|
||||
p.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Clear resets the pending flow to empty. Any stored wait-cancel is
|
||||
// dropped without being invoked — call CancelWait first if the waiting
|
||||
// goroutine must be stopped.
|
||||
func (p *PendingFlow) Clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.flow = nil
|
||||
p.info = AuthFlowInfo{}
|
||||
p.expiresAt = time.Time{}
|
||||
p.waitCancel = nil
|
||||
}
|
||||
74
client/internal/auth/sessionwatch/event.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// internal event kinds are no longer exposed: the watcher drives the Sink
|
||||
// directly (NotifyStateChange on deadline change/clear, PublishEvent at
|
||||
// each warning lead). Tests use a mock Sink to observe what the watcher
|
||||
// emits.
|
||||
|
||||
// Metadata keys attached by the daemon to session-warning SystemEvents.
|
||||
// The UI tray reads these to build a locale-aware notification without
|
||||
// relying on the daemon's locale-less UserMessage string, and to
|
||||
// disambiguate the T-WarningLead notification from the T-FinalWarningLead
|
||||
// fallback that auto-opens the SessionAboutToExpire dialog.
|
||||
const (
|
||||
// MetaSessionWarning is set to "true" on both warning events (T-10 and
|
||||
// T-2) so the UI can detect a session-warning SystemEvent without
|
||||
// matching on the message text. Use MetaSessionFinal to distinguish
|
||||
// the two.
|
||||
MetaSessionWarning = "session_warning"
|
||||
// MetaSessionFinal is set to "true" on the T-FinalWarningLead event
|
||||
// only. Consumers that need to auto-open the SessionAboutToExpire
|
||||
// dialog gate on this; T-WarningLead events leave the field unset.
|
||||
MetaSessionFinal = "session_final_warning"
|
||||
// MetaSessionExpiresAt carries the absolute UTC deadline encoded with
|
||||
// FormatExpiresAt; consumers must decode with ParseExpiresAt so a
|
||||
// future format change stays a single edit.
|
||||
MetaSessionExpiresAt = "session_expires_at"
|
||||
// MetaSessionLeadMinutes carries the lead in whole minutes (WarningLead
|
||||
// for the T-10 event, FinalWarningLead for the T-2 event) so the UI
|
||||
// can show "expires in ~N minutes" without hardcoding either constant.
|
||||
MetaSessionLeadMinutes = "lead_minutes"
|
||||
)
|
||||
|
||||
// expiresAtLayout is the wire format used for MetaSessionExpiresAt.
|
||||
// Producer and consumers both go through FormatExpiresAt/ParseExpiresAt
|
||||
// so this layout stays a single source of truth.
|
||||
const expiresAtLayout = time.RFC3339
|
||||
|
||||
// FormatExpiresAt encodes a deadline for MetaSessionExpiresAt. Always
|
||||
// emits UTC so a consumer in another timezone reads the same wall-clock
|
||||
// deadline.
|
||||
func FormatExpiresAt(t time.Time) string {
|
||||
return t.UTC().Format(expiresAtLayout)
|
||||
}
|
||||
|
||||
// ParseExpiresAt decodes the MetaSessionExpiresAt value back to a UTC
|
||||
// time. Returns an error when the field is empty or malformed; the
|
||||
// caller decides whether to fall back (zero value) or propagate.
|
||||
func ParseExpiresAt(s string) (time.Time, error) {
|
||||
t, err := time.Parse(expiresAtLayout, s)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return t.UTC(), nil
|
||||
}
|
||||
|
||||
// FormatLeadMinutes encodes a lead duration for MetaSessionLeadMinutes
|
||||
// as the integer count of whole minutes. Sub-minute residuals are
|
||||
// truncated — the field is informational ("expires in ~N minutes") and
|
||||
// fractional minutes don't change what the UI displays.
|
||||
func FormatLeadMinutes(d time.Duration) string {
|
||||
return strconv.Itoa(int(d / time.Minute))
|
||||
}
|
||||
|
||||
// ParseLeadMinutes decodes a MetaSessionLeadMinutes value. Returns 0
|
||||
// and the parse error for malformed input; consumers that prefer a
|
||||
// silent fallback can simply ignore the error.
|
||||
func ParseLeadMinutes(s string) (int, error) {
|
||||
return strconv.Atoi(s)
|
||||
}
|
||||
362
client/internal/auth/sessionwatch/watcher.go
Normal file
@@ -0,0 +1,362 @@
|
||||
// Package sessionwatch tracks the SSO session expiry deadline that the
|
||||
// management server publishes via LoginResponse / SyncResponse and fires
|
||||
// two warning events at fixed lead times before expiry: an interactive
|
||||
// T-WarningLead notification and a dismiss-gated T-FinalWarningLead
|
||||
// fallback dialog.
|
||||
//
|
||||
// The watcher is idempotent: Update may be called as often as the network
|
||||
// map snapshots arrive. Repeating the same deadline is a no-op; a new
|
||||
// deadline reschedules the timers and arms a fresh warning cycle.
|
||||
//
|
||||
// Warning firing is edge-detected. Each unique deadline value fires each
|
||||
// warning callback at most once.
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
// Skew tolerates a small clock difference between the management
|
||||
// server and this peer before treating a deadline as "in the past".
|
||||
// Slightly above typical NTP drift; tight enough that the UI doesn't
|
||||
// paint a stale expiry as if it were valid.
|
||||
Skew = 30 * time.Second
|
||||
|
||||
// maxDeadlineHorizon caps how far in the future an accepted deadline
|
||||
// can sit. A timestamp beyond this is almost certainly a protocol
|
||||
// glitch, and silently arming a 100-year timer would hide the bug.
|
||||
maxDeadlineHorizon = 10 * 365 * 24 * time.Hour
|
||||
|
||||
// WarningLead is how far before expiry the first (interactive)
|
||||
// warning fires. Drives the T-10 OS notification with
|
||||
// Extend/Dismiss actions.
|
||||
WarningLead = 10 * time.Minute
|
||||
|
||||
// FinalWarningLead is how far before expiry the fallback final
|
||||
// warning fires. Drives the auto-opened SessionAboutToExpire dialog,
|
||||
// but only when the user has not dismissed the T-WarningLead warning
|
||||
// for the same deadline. Must be strictly less than WarningLead.
|
||||
FinalWarningLead = 2 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDeadlineBeforeEpoch is returned by Update when the supplied
|
||||
// deadline pre-dates 1970-01-01.
|
||||
ErrDeadlineBeforeEpoch = errors.New("session deadline before unix epoch")
|
||||
|
||||
// ErrDeadlineTooFarFuture is returned by Update when the supplied
|
||||
// deadline is more than maxDeadlineHorizon in the future.
|
||||
ErrDeadlineTooFarFuture = errors.New("session deadline too far in the future")
|
||||
|
||||
// ErrDeadlineInPast is returned by Update when the supplied deadline
|
||||
// is more than Skew in the past.
|
||||
ErrDeadlineInPast = errors.New("session deadline in the past")
|
||||
)
|
||||
|
||||
// StatusRecorder is the side-effect surface the watcher drives on every
|
||||
// state transition. Production wires this to peer.Status (NotifyStateChange
|
||||
// for deadline change/clear, PublishEvent for the two warnings); tests pass
|
||||
// a fake recorder so the same surface is observable without an engine.
|
||||
//
|
||||
// PublishEvent's signature mirrors peer.Status.PublishEvent: the watcher
|
||||
// composes the metadata internally so the wire format (MetaSession*) is
|
||||
// owned by sessionwatch, not the caller.
|
||||
type StatusRecorder interface {
|
||||
NotifyStateChange()
|
||||
PublishEvent(
|
||||
severity cProto.SystemEvent_Severity,
|
||||
category cProto.SystemEvent_Category,
|
||||
message string,
|
||||
userMessage string,
|
||||
metadata map[string]string,
|
||||
)
|
||||
}
|
||||
|
||||
// Watcher observes the latest session deadline and fires two warnings
|
||||
// before it expires: the interactive T-WarningLead notification, and the
|
||||
// fallback T-FinalWarningLead dialog (suppressed when the user dismissed
|
||||
// the first one for the same deadline). Safe for concurrent use.
|
||||
type Watcher struct {
|
||||
lead time.Duration
|
||||
finalLead time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
current time.Time
|
||||
timer *time.Timer
|
||||
finalTimer *time.Timer
|
||||
firedAt time.Time // deadline value the T-WarningLead callback last fired against
|
||||
finalFiredAt time.Time // deadline value the T-FinalWarningLead callback last fired against
|
||||
dismissedAt time.Time // deadline value the user dismissed via Dismiss(); gates fireFinal
|
||||
closed bool
|
||||
recorder StatusRecorder
|
||||
}
|
||||
|
||||
// New returns a watcher with the package defaults WarningLead and
|
||||
// FinalWarningLead. Pass nil for recorder to silence side effects (handy
|
||||
// in unit tests that exercise sanity checks without observing the publish
|
||||
// path).
|
||||
func New(recorder StatusRecorder) *Watcher {
|
||||
return NewWithLeads(WarningLead, FinalWarningLead, recorder)
|
||||
}
|
||||
|
||||
// NewWithLeads returns a watcher with custom lead times. Useful for tests.
|
||||
// final must be strictly less than lead; otherwise both timers fire in the
|
||||
// wrong order or simultaneously and the UI flow breaks. A zero final lead
|
||||
// disables the final-warning timer entirely (see armTimerLocked) so a
|
||||
// millisecond-scale deadline doesn't flush both timers in one tick.
|
||||
func NewWithLeads(lead, final time.Duration, recorder StatusRecorder) *Watcher {
|
||||
return &Watcher{
|
||||
lead: lead,
|
||||
finalLead: final,
|
||||
recorder: recorder,
|
||||
}
|
||||
}
|
||||
|
||||
// Update sets the latest deadline. Pass the zero time to clear (e.g. when
|
||||
// a Sync push from the server omits the field because login expiration
|
||||
// was disabled).
|
||||
//
|
||||
// Same-value updates are no-ops. A different non-zero value cancels any
|
||||
// pending timer, resets the "already fired" guard, and arms a new one.
|
||||
//
|
||||
// Returns one of the sentinel Err* values when the deadline fails the
|
||||
// sanity checks (pre-epoch, far future, or in the past beyond Skew).
|
||||
// In every error case the watcher first clears its state so it stays
|
||||
// consistent with what the caller will push into its other sinks (e.g.
|
||||
// applySessionDeadline forces a zero deadline into the status recorder
|
||||
// after a non-nil error).
|
||||
func (w *Watcher) Update(deadline time.Time) error {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
if deadline.IsZero() {
|
||||
w.clearLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
switch {
|
||||
case deadline.Before(time.Unix(0, 0)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v", ErrDeadlineBeforeEpoch, deadline)
|
||||
case deadline.After(now.Add(maxDeadlineHorizon)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v", ErrDeadlineTooFarFuture, deadline)
|
||||
case deadline.Before(now.Add(-Skew)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v (now=%v)", ErrDeadlineInPast, deadline, now)
|
||||
}
|
||||
|
||||
if deadline.Equal(w.current) {
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
w.stopTimerLocked()
|
||||
w.current = deadline
|
||||
// Reset every per-deadline guard so a refreshed deadline arms a fresh
|
||||
// warning cycle: both edge triggers and the user Dismiss decision
|
||||
// (the user agreed to the old deadline expiring; a new deadline
|
||||
// restarts the contract).
|
||||
w.firedAt = time.Time{}
|
||||
w.finalFiredAt = time.Time{}
|
||||
w.dismissedAt = time.Time{}
|
||||
|
||||
w.armTimerLocked(deadline)
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder != nil {
|
||||
recorder.NotifyStateChange()
|
||||
}
|
||||
log.Infof("auth session deadline set to: %s (in %s)", deadline.Format(time.RFC3339), time.Until(deadline).Round(time.Second))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deadline returns the most recently observed deadline. Zero when no
|
||||
// deadline is currently tracked.
|
||||
func (w *Watcher) Deadline() time.Time {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.current
|
||||
}
|
||||
|
||||
// Dismiss records the user's "Dismiss" action against the current deadline
|
||||
// and suppresses the upcoming final-warning callback for that deadline.
|
||||
// Idempotent: repeated calls are no-ops. A subsequent Update with a fresh
|
||||
// deadline resets the dismissal so the final-warning cycle re-arms.
|
||||
//
|
||||
// No-op when the watcher holds no deadline or has been closed.
|
||||
func (w *Watcher) Dismiss() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.closed || w.current.IsZero() {
|
||||
return
|
||||
}
|
||||
if w.dismissedAt.Equal(w.current) {
|
||||
return
|
||||
}
|
||||
w.dismissedAt = w.current
|
||||
// Cancel the armed final-warning timer eagerly. fireFinal would also
|
||||
// gate on dismissedAt, but stopping the timer avoids a wakeup with
|
||||
// nothing to do and makes the intent visible.
|
||||
if w.finalTimer != nil {
|
||||
w.finalTimer.Stop()
|
||||
w.finalTimer = nil
|
||||
}
|
||||
log.Infof("auth session final-warning dismissed for deadline %s", w.current.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// Close stops any pending timer. Update calls after Close are ignored.
|
||||
func (w *Watcher) Close() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
w.closed = true
|
||||
w.stopTimerLocked()
|
||||
}
|
||||
|
||||
// clearLocked drops the tracked deadline and notifies the recorder so
|
||||
// downstream consumers (SubscribeStatus stream, UI) drop their anchor.
|
||||
// The caller must hold w.mu; this helper releases it before invoking
|
||||
// the recorder.
|
||||
func (w *Watcher) clearLocked() {
|
||||
if w.current.IsZero() {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.stopTimerLocked()
|
||||
w.current = time.Time{}
|
||||
w.firedAt = time.Time{}
|
||||
w.finalFiredAt = time.Time{}
|
||||
w.dismissedAt = time.Time{}
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder != nil {
|
||||
recorder.NotifyStateChange()
|
||||
}
|
||||
log.Infof("auth session deadline cleared")
|
||||
}
|
||||
|
||||
func (w *Watcher) stopTimerLocked() {
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
w.timer = nil
|
||||
}
|
||||
if w.finalTimer != nil {
|
||||
w.finalTimer.Stop()
|
||||
w.finalTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) armTimerLocked(deadline time.Time) {
|
||||
w.timer = armOneShotLocked(deadline.Add(-w.lead), func() { w.fire(deadline) })
|
||||
// finalLead <= 0 disables the final-warning timer entirely. Used by
|
||||
// tests that predate the final-warning fallback so a millisecond-scale
|
||||
// deadline does not flush both timers at once.
|
||||
if w.finalLead > 0 {
|
||||
w.finalTimer = armOneShotLocked(deadline.Add(-w.finalLead), func() { w.fireFinal(deadline) })
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) fire(armedFor time.Time) {
|
||||
w.mu.Lock()
|
||||
if w.closed || !w.current.Equal(armedFor) {
|
||||
// Deadline moved while we were waiting (e.g. a successful extend).
|
||||
// The reschedule path armed a fresh timer; this one is stale.
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if !w.firedAt.IsZero() && w.firedAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.firedAt = armedFor
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder == nil {
|
||||
return
|
||||
}
|
||||
log.Infof("auth session expiry soon warning fired")
|
||||
publishWarning(recorder, armedFor, false)
|
||||
}
|
||||
|
||||
// fireFinal mirrors fire for the T-FinalWarningLead timer with an extra
|
||||
// dismiss-gate: if the user dismissed the T-WarningLead notification for
|
||||
// this deadline, the final warning is suppressed entirely.
|
||||
func (w *Watcher) fireFinal(armedFor time.Time) {
|
||||
w.mu.Lock()
|
||||
if w.closed || !w.current.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if !w.finalFiredAt.IsZero() && w.finalFiredAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if w.dismissedAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
log.Infof("auth session final-warning skipped (dismissed by user)")
|
||||
return
|
||||
}
|
||||
w.finalFiredAt = armedFor
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder == nil {
|
||||
return
|
||||
}
|
||||
log.Infof("auth session final-warning fired")
|
||||
publishWarning(recorder, armedFor, true)
|
||||
}
|
||||
|
||||
// armOneShotLocked schedules cb at fireAt. When fireAt is already in the
|
||||
// past it dispatches on the next scheduler tick so a state-change recorder
|
||||
// notification (invoked after w.mu is released) lands first. Caller must
|
||||
// hold w.mu.
|
||||
func armOneShotLocked(fireAt time.Time, cb func()) *time.Timer {
|
||||
delay := time.Until(fireAt)
|
||||
if delay <= 0 {
|
||||
return time.AfterFunc(0, cb)
|
||||
}
|
||||
return time.AfterFunc(delay, cb)
|
||||
}
|
||||
|
||||
// publishWarning composes the SystemEvent for a watcher-fired warning and
|
||||
// pushes it through the recorder. Severity is CRITICAL on both — bypassing
|
||||
// the user's Notifications toggle is deliberate: missing the warning
|
||||
// window forces the post-mortem SessionExpired flow (tunnel torn down,
|
||||
// lock icon, manual re-login), which is the UX we are trying to avoid.
|
||||
func publishWarning(recorder StatusRecorder, deadline time.Time, final bool) {
|
||||
lead := WarningLead
|
||||
message := "session expiry warning"
|
||||
meta := map[string]string{
|
||||
MetaSessionWarning: "true",
|
||||
MetaSessionExpiresAt: FormatExpiresAt(deadline),
|
||||
}
|
||||
if final {
|
||||
lead = FinalWarningLead
|
||||
message = "session expiry final warning"
|
||||
meta[MetaSessionFinal] = "true"
|
||||
}
|
||||
meta[MetaSessionLeadMinutes] = FormatLeadMinutes(lead)
|
||||
|
||||
recorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_AUTHENTICATION,
|
||||
message,
|
||||
"",
|
||||
meta,
|
||||
)
|
||||
}
|
||||
463
client/internal/auth/sessionwatch/watcher_test.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// fakeRecorder satisfies StatusRecorder and records every call so tests
|
||||
// can observe what the watcher emits. NotifyStateChange and PublishEvent
|
||||
// land in the same ordered events slice (with the Kind distinguishing
|
||||
// them) so tests that care about ordering still work.
|
||||
type fakeRecorder struct {
|
||||
mu sync.Mutex
|
||||
events []event
|
||||
}
|
||||
|
||||
type eventKind int
|
||||
|
||||
const (
|
||||
stateChange eventKind = iota
|
||||
publish
|
||||
)
|
||||
|
||||
type event struct {
|
||||
kind eventKind
|
||||
// Set only for publish events.
|
||||
severity cProto.SystemEvent_Severity
|
||||
category cProto.SystemEvent_Category
|
||||
message string
|
||||
meta map[string]string
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) NotifyStateChange() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.events = append(r.events, event{kind: stateChange})
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) PublishEvent(
|
||||
severity cProto.SystemEvent_Severity,
|
||||
category cProto.SystemEvent_Category,
|
||||
message string,
|
||||
_ string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.events = append(r.events, event{
|
||||
kind: publish,
|
||||
severity: severity,
|
||||
category: category,
|
||||
message: message,
|
||||
meta: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) snapshot() []event {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([]event, len(r.events))
|
||||
copy(out, r.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e event) isFinalWarning() bool {
|
||||
return e.kind == publish && e.meta[MetaSessionFinal] == "true"
|
||||
}
|
||||
|
||||
func (e event) isWarning() bool {
|
||||
return e.kind == publish && e.meta[MetaSessionWarning] == "true" && e.meta[MetaSessionFinal] != "true"
|
||||
}
|
||||
|
||||
func countWhere(events []event, pred func(event) bool) int {
|
||||
n := 0
|
||||
for _, e := range events {
|
||||
if pred(e) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func waitForEvents(t *testing.T, r *fakeRecorder, want int) []event {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if got := r.snapshot(); len(got) >= want {
|
||||
return got
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
got := r.snapshot()
|
||||
t.Fatalf("timed out waiting for %d events, got %d: %+v", want, len(got), got)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newWatcher builds a watcher with the final timer disabled (finalLead=0),
|
||||
// matching the lead-only behaviour the pre-final-warning tests assume.
|
||||
func newWatcher(lead time.Duration, r *fakeRecorder) *Watcher {
|
||||
return NewWithLeads(lead, 0, r)
|
||||
}
|
||||
|
||||
func TestUpdateZeroBeforeAnythingIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
_ = w.Update(time.Time{})
|
||||
|
||||
if got := r.snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected no events on initial zero, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNonZeroFiresStateChange(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(time.Hour)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 1)
|
||||
if events[0].kind != stateChange {
|
||||
t.Fatalf("expected stateChange, got %+v", events[0])
|
||||
}
|
||||
if !w.Deadline().Equal(d) {
|
||||
t.Fatalf("deadline mismatch: %v vs %v", w.Deadline(), d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSameDeadlineIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(time.Hour)
|
||||
_ = w.Update(d)
|
||||
_ = w.Update(d)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 1)
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected exactly 1 event for repeated same deadline, got %d: %+v", len(events), events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarningFiresOnceWithinLeadWindow(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 50 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
// Deadline 80ms out — warning should fire after ~30ms.
|
||||
d := time.Now().Add(80 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[0].kind != stateChange {
|
||||
t.Fatalf("event[0] should be stateChange, got %+v", events[0])
|
||||
}
|
||||
if !events[1].isWarning() {
|
||||
t.Fatalf("event[1] should be a warning publish, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarningFiresImmediatelyWhenAlreadyInsideWindow(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r) // lead > delta => fire immediately
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(10 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if !events[1].isWarning() {
|
||||
t.Fatalf("expected immediate warning publish, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDeadlineCancelsPriorTimer(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 50 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(80 * time.Millisecond) // would fire warning ~30ms in
|
||||
_ = w.Update(first)
|
||||
|
||||
// Replace with a far-future deadline before the warning fires.
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
second := time.Now().Add(time.Hour)
|
||||
_ = w.Update(second)
|
||||
|
||||
// Wait past when first's warning would have fired.
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
|
||||
if n := countWhere(r.snapshot(), event.isWarning); n != 0 {
|
||||
t.Fatalf("warning fired for cancelled deadline: %+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAfterFireArmsNewWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 30 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(50 * time.Millisecond)
|
||||
_ = w.Update(first)
|
||||
|
||||
// Wait for stateChange + warning of the first cycle.
|
||||
waitForEvents(t, r, 2)
|
||||
|
||||
// Simulate a successful extend: brand new deadline.
|
||||
second := time.Now().Add(60 * time.Millisecond)
|
||||
_ = w.Update(second)
|
||||
|
||||
// 4 events total: stateChange, warning (first), stateChange, warning (second).
|
||||
events := waitForEvents(t, r, 4)
|
||||
if events[2].kind != stateChange {
|
||||
t.Fatalf("event[2] should be stateChange for the new deadline, got %+v", events[2])
|
||||
}
|
||||
if !events[3].isWarning() {
|
||||
t.Fatalf("event[3] should be a warning publish for the new deadline, got %+v", events[3])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateZeroAfterNonZeroClearsState(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(2 * time.Hour)
|
||||
_ = w.Update(d)
|
||||
waitForEvents(t, r, 1)
|
||||
|
||||
_ = w.Update(time.Time{})
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[1].kind != stateChange {
|
||||
t.Fatalf("expected stateChange on clear, got %+v", events[1])
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("Deadline should be zero after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRejectsBeforeEpoch(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
|
||||
err := w.Update(time.Unix(-100, 0))
|
||||
if !errors.Is(err, ErrDeadlineBeforeEpoch) {
|
||||
t.Fatalf("want ErrDeadlineBeforeEpoch, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("rejected pre-epoch update must clear deadline; got %v", w.Deadline())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRejectsTooFarFuture(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
|
||||
err := w.Update(time.Now().Add(50 * 365 * 24 * time.Hour))
|
||||
if !errors.Is(err, ErrDeadlineTooFarFuture) {
|
||||
t.Fatalf("want ErrDeadlineTooFarFuture, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("rejected far-future update must clear deadline; got %v", w.Deadline())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInPastClearsDeadline(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
// Drain the stateChange from the seed.
|
||||
waitForEvents(t, r, 1)
|
||||
|
||||
err := w.Update(time.Now().Add(-1 * time.Hour))
|
||||
if !errors.Is(err, ErrDeadlineInPast) {
|
||||
t.Fatalf("want ErrDeadlineInPast, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("in-past update must clear the deadline, got %v", w.Deadline())
|
||||
}
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[1].kind != stateChange {
|
||||
t.Fatalf("expected stateChange on clear, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWithinSkewAccepted(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
// 5 seconds in the past is within the 30s Skew tolerance — accept it.
|
||||
d := time.Now().Add(-5 * time.Second)
|
||||
if err := w.Update(d); err != nil {
|
||||
t.Fatalf("within-skew Update should succeed, got %v", err)
|
||||
}
|
||||
if !w.Deadline().Equal(d) {
|
||||
t.Fatalf("expected deadline to be applied, got %v want %v", w.Deadline(), d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseSilencesUpdates(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
w.Close()
|
||||
|
||||
_ = w.Update(time.Now().Add(time.Hour))
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if got := r.snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected no events after Close, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalWarningFiresAfterRegularWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
// Warning fires at deadline-80ms, final at deadline-30ms.
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Expect stateChange + warning + final-warning.
|
||||
events := waitForEvents(t, r, 3)
|
||||
|
||||
if countWhere(events, func(e event) bool { return e.kind == stateChange }) != 1 {
|
||||
t.Fatalf("expected exactly 1 stateChange, got %+v", events)
|
||||
}
|
||||
if countWhere(events, event.isWarning) != 1 {
|
||||
t.Fatalf("expected exactly 1 warning publish, got %+v", events)
|
||||
}
|
||||
if countWhere(events, event.isFinalWarning) != 1 {
|
||||
t.Fatalf("expected exactly 1 final-warning publish, got %+v", events)
|
||||
}
|
||||
|
||||
// Warning must precede final (same deadline, longer lead fires first).
|
||||
var wIdx, fIdx int
|
||||
for i, e := range events {
|
||||
switch {
|
||||
case e.isWarning():
|
||||
wIdx = i
|
||||
case e.isFinalWarning():
|
||||
fIdx = i
|
||||
}
|
||||
}
|
||||
if wIdx > fIdx {
|
||||
t.Fatalf("warning must publish before final-warning, got order %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissSuppressesFinalWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Wait for the warning publish so we know we're inside the warning
|
||||
// window, then dismiss before the final timer would fire.
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isWarning) >= 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
if countWhere(r.snapshot(), event.isWarning) < 1 {
|
||||
t.Fatalf("warning did not publish in time, events=%+v", r.snapshot())
|
||||
}
|
||||
|
||||
w.Dismiss()
|
||||
|
||||
// Now wait past when the final would have fired.
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
if n := countWhere(r.snapshot(), event.isFinalWarning); n != 0 {
|
||||
t.Fatalf("final-warning published after Dismiss(), events=%+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissResetByNewDeadline(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(first)
|
||||
|
||||
// Dismiss against the first deadline.
|
||||
w.Dismiss()
|
||||
|
||||
// Replace with a fresh deadline before the first's timers complete.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
second := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(second)
|
||||
|
||||
// The second cycle must publish a final-warning (the dismiss state
|
||||
// did not carry over).
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) >= 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) < 1 {
|
||||
t.Fatalf("final-warning did not publish on fresh deadline after Dismiss reset, events=%+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissBeforeUpdateIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
// No deadline tracked yet; Dismiss must be a no-op (no panic, no state).
|
||||
w.Dismiss()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Final warning should still publish — Dismiss only acts on the current
|
||||
// deadline, and there was none at the time of the call.
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) >= 1 {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("final-warning did not publish after no-op pre-Update Dismiss, events=%+v", r.snapshot())
|
||||
}
|
||||
@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
@@ -258,6 +256,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
// On daemon shutdown / Down() the parent context is cancelled
|
||||
// and the dial fails with "context canceled". Wrapping that
|
||||
// into state would leave the snapshot stuck at Connecting+err
|
||||
// until the backoff loop wakes up — instead let the operation
|
||||
// return cleanly so the deferred state.Set(StatusIdle) takes
|
||||
// effect on the next iteration.
|
||||
if c.ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||
@@ -386,6 +393,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
// Seed the session-expiry deadline from the LoginResponse. Subsequent
|
||||
// changes flow in through SyncResponse and are applied in handleSync.
|
||||
engine.ApplySessionDeadline(loginResp.GetSessionExpiresAt())
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
@@ -426,7 +437,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
c.statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
// Wrap the backoff with c.ctx so Down()/actCancel propagates into the
|
||||
// inter-attempt sleep — otherwise a 15s MaxInterval can keep the retry
|
||||
// loop alive long after the caller asked to give up, leaving the
|
||||
// status stream stuck at Connecting.
|
||||
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
|
||||
@@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
|
||||
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
|
||||
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
|
||||
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
|
||||
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
@@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
|
||||
Firewall Rules (Linux only)
|
||||
The bundle includes two separate firewall rule files:
|
||||
The bundle includes the following firewall-related files:
|
||||
|
||||
iptables.txt:
|
||||
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
|
||||
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
|
||||
- Includes all tables (filter, nat, mangle, raw, security)
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
|
||||
ip6tables.txt:
|
||||
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
|
||||
- Same table coverage and anonymization as iptables.txt
|
||||
- Omitted when ip6tables is not installed or no IPv6 rules are present
|
||||
|
||||
ipset.txt:
|
||||
- Output of 'ipset list' (family-agnostic)
|
||||
- IP addresses are anonymized; set names and types remain unchanged
|
||||
|
||||
nftables.txt:
|
||||
- Complete nftables ruleset obtained via 'nft -a list ruleset'
|
||||
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
|
||||
- Includes rule handle numbers and packet counters
|
||||
- All tables, chains, and rules are included
|
||||
- Shows packet and byte counters for each rule
|
||||
- All IP addresses are anonymized
|
||||
- Chain names, table names, and other non-sensitive information remain unchanged
|
||||
- All IP addresses are anonymized; chain/table names remain unchanged
|
||||
|
||||
sysctls.txt:
|
||||
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
|
||||
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
|
||||
- Interface names anonymized when --anonymize is set
|
||||
|
||||
IP Rules (Linux only)
|
||||
The ip_rules.txt file contains detailed IP routing rule information:
|
||||
@@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addSysctls(); err != nil {
|
||||
log.Errorf("failed to add sysctls to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
@@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) {
|
||||
// addFirewallRules collects and adds firewall rules to the archive
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
log.Info("Collecting firewall rules")
|
||||
iptablesRules, err := collectIPTablesRules()
|
||||
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
|
||||
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
|
||||
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
if g.anonymize {
|
||||
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
|
||||
log.Warnf("Failed to add ipset output to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectIPTablesRules collects rules using both iptables-save and verbose listing
|
||||
func collectIPTablesRules() (string, error) {
|
||||
var builder strings.Builder
|
||||
|
||||
saveOutput, err := collectIPTablesSave()
|
||||
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
|
||||
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
|
||||
rules, err := collectIPTablesRules(saveBin, listBin)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== iptables-save output ===\n")
|
||||
log.Warnf("Failed to collect %s rules: %v", listBin, err)
|
||||
return
|
||||
}
|
||||
if g.anonymize {
|
||||
rules = g.anonymizer.AnonymizeString(rules)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
|
||||
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
|
||||
}
|
||||
}
|
||||
|
||||
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
|
||||
// Returns an error when neither command produced any output (e.g. the binary is missing),
|
||||
// so the caller can skip writing an empty file.
|
||||
func collectIPTablesRules(saveBin, listBin string) (string, error) {
|
||||
var builder strings.Builder
|
||||
var collected bool
|
||||
var firstErr error
|
||||
|
||||
saveOutput, err := runCommand(saveBin)
|
||||
switch {
|
||||
case err != nil:
|
||||
firstErr = err
|
||||
log.Warnf("Failed to collect %s output: %v", saveBin, err)
|
||||
case strings.TrimSpace(saveOutput) == "":
|
||||
log.Debugf("%s produced no output, skipping", saveBin)
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
|
||||
builder.WriteString(saveOutput)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== ipset list output ===\n")
|
||||
builder.WriteString(ipsetOutput)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
|
||||
builder.WriteString(listHeader)
|
||||
|
||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||
|
||||
for _, table := range tables {
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
|
||||
stats, err := getTableStatistics(table)
|
||||
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
builder.WriteString(stats)
|
||||
builder.WriteString("\n")
|
||||
collected = true
|
||||
}
|
||||
|
||||
if !collected {
|
||||
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
@@ -214,34 +238,15 @@ func collectIPSets() (string, error) {
|
||||
return ipsets, nil
|
||||
}
|
||||
|
||||
// collectIPTablesSave uses iptables-save to get rule definitions
|
||||
func collectIPTablesSave() (string, error) {
|
||||
cmd := exec.Command("iptables-save")
|
||||
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
|
||||
func runCommand(name string, args ...string) (string, error) {
|
||||
cmd := exec.Command(name, args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
rules := stdout.String()
|
||||
if strings.TrimSpace(rules) == "" {
|
||||
return "", fmt.Errorf("no iptables rules found")
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// getTableStatistics gets verbose statistics for an entire table using iptables command
|
||||
func getTableStatistics(table string) (string, error) {
|
||||
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
|
||||
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
@@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
|
||||
return fmt.Sprintf("type-%v", keyType)
|
||||
}
|
||||
}
|
||||
|
||||
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
log.Info("Collecting sysctls")
|
||||
content := collectSysctls()
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
|
||||
return fmt.Errorf("add sysctls to bundle: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectSysctls reads every sysctl that the netbird client may modify, plus
|
||||
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
|
||||
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
|
||||
func collectSysctls() string {
|
||||
var builder strings.Builder
|
||||
|
||||
writeSysctlGroup(&builder, "forwarding", []string{
|
||||
"net.ipv4.ip_forward",
|
||||
"net.ipv6.conf.all.forwarding",
|
||||
"net.ipv6.conf.default.forwarding",
|
||||
})
|
||||
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
|
||||
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
|
||||
writeSysctlGroup(&builder, "rp_filter", append(
|
||||
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
|
||||
listInterfaceSysctls("ipv4", "rp_filter")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "src_valid_mark", append(
|
||||
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
|
||||
listInterfaceSysctls("ipv4", "src_valid_mark")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "conntrack", []string{
|
||||
"net.netfilter.nf_conntrack_acct",
|
||||
"net.netfilter.nf_conntrack_tcp_loose",
|
||||
})
|
||||
writeSysctlGroup(&builder, "tcp", []string{
|
||||
"net.ipv4.tcp_tw_reuse",
|
||||
})
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
|
||||
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
|
||||
for _, key := range keys {
|
||||
value, err := readSysctl(key)
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
|
||||
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
|
||||
// (callers add those explicitly so they appear first).
|
||||
func listInterfaceSysctls(family, leaf string) []string {
|
||||
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var keys []string
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if name == "all" || name == "default" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func readSysctl(key string) (string, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
value, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(value)), nil
|
||||
}
|
||||
|
||||
@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
|
||||
// IP rules are only supported on Linux
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addSysctls() error {
|
||||
// Sysctl collection is only supported on Linux
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,10 @@ type hostManager interface {
|
||||
restoreHostDNS() error
|
||||
supportCustomPort() bool
|
||||
string() string
|
||||
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
|
||||
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
|
||||
// hardcoded Quad9 on iOS, nil for noop / mock.
|
||||
getOriginalNameservers() []netip.Addr
|
||||
}
|
||||
|
||||
type SystemDNSSettings struct {
|
||||
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
|
||||
func (n noopHostConfigurator) string() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// androidHostManager is a noop on the OS side (Android's VPN service handles
|
||||
// DNS for us) but tracks the OS-reported resolver list pushed via
|
||||
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
|
||||
type androidHostManager struct {
|
||||
holder *hostsDNSHolder
|
||||
}
|
||||
|
||||
func newHostManager() (*androidHostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
|
||||
return &androidHostManager{holder: holder}, nil
|
||||
}
|
||||
|
||||
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
|
||||
func (a androidHostManager) string() string {
|
||||
return "none"
|
||||
}
|
||||
|
||||
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
|
||||
hosts := a.holder.get()
|
||||
out := make([]netip.Addr, 0, len(hosts))
|
||||
for ap := range hosts {
|
||||
out = append(out, ap.Addr())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
|
||||
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
|
||||
return []netip.Addr{
|
||||
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
|
||||
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
|
||||
}
|
||||
}
|
||||
|
||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
jsonData, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -44,9 +45,11 @@ const (
|
||||
|
||||
nrptMaxDomainsPerRule = 50
|
||||
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
|
||||
// Network interface DNS registration settings
|
||||
disableDynamicUpdateKey = "DisableDynamicUpdate"
|
||||
@@ -67,10 +70,11 @@ const (
|
||||
)
|
||||
|
||||
type registryConfigurator struct {
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
gpo: useGPO,
|
||||
}
|
||||
|
||||
origNameservers, err := configurator.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
configurator.origNameservers = origNameservers
|
||||
|
||||
if err := configurator.configureInterface(); err != nil {
|
||||
log.Errorf("failed to configure interface settings: %v", err)
|
||||
}
|
||||
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
return configurator, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
|
||||
// registry key except the WG adapter. v4 and v6 servers live in separate
|
||||
// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
|
||||
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
var merr *multierror.Error
|
||||
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
|
||||
addrs, err := r.captureFromTcpipRoot(root)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
|
||||
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open key: %w", err)
|
||||
}
|
||||
defer closer(root)
|
||||
|
||||
guids, err := root.ReadSubKeyNames(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read subkeys: %w", err)
|
||||
}
|
||||
|
||||
var out []netip.Addr
|
||||
for _, guid := range guids {
|
||||
if strings.EqualFold(guid, r.guid) {
|
||||
continue
|
||||
}
|
||||
out = append(out, readInterfaceNameservers(rootPath, guid)...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
|
||||
keyPath := rootPath + "\\" + guid
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closer(k)
|
||||
|
||||
// Static NameServer wins over DhcpNameServer for actual resolution.
|
||||
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
|
||||
raw, _, err := k.GetStringValue(name)
|
||||
if err != nil || raw == "" {
|
||||
continue
|
||||
}
|
||||
if out := parseRegistryNameservers(raw); len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRegistryNameservers(raw string) []netip.Addr {
|
||||
var out []netip.Addr
|
||||
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
|
||||
addr, err := netip.ParseAddr(strings.TrimSpace(field))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
// Drop unzoned link-local: not routable without a scope id. If
|
||||
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
|
||||
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, addr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(r.origNameservers)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
|
||||
h.mutex.RLock()
|
||||
l := h.unprotectedDNSList
|
||||
|
||||
@@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
func (d *Resolver) ProbeAvailability(context.Context) {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
|
||||
return make([]string, 0)
|
||||
}
|
||||
|
||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||
func (m *MockServer) ProbeAvailability() {
|
||||
}
|
||||
|
||||
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
if m.UpdateServerConfigFunc != nil {
|
||||
return m.UpdateServerConfigFunc(domains)
|
||||
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
|
||||
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// SetRouteSources mock implementation of SetRouteSources from Server interface
|
||||
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -32,6 +33,15 @@ const (
|
||||
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
|
||||
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
|
||||
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
|
||||
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
|
||||
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
|
||||
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
|
||||
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
|
||||
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
|
||||
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
|
||||
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
|
||||
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
|
||||
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
|
||||
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
|
||||
networkManagerDbusIPv4Key = "ipv4"
|
||||
networkManagerDbusIPv6Key = "ipv6"
|
||||
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
|
||||
}
|
||||
|
||||
type networkManagerDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
|
||||
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
|
||||
|
||||
return &networkManagerDbusConfigurator{
|
||||
c := &networkManagerDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
ifaceName: wgInterface,
|
||||
}, nil
|
||||
}
|
||||
|
||||
origNameservers, err := c.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from NetworkManager: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
c.origNameservers = origNameservers
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads DNS servers from every NM device's
|
||||
// IP4Config / IP6Config except our WG device.
|
||||
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
devices, err := networkManagerListDevices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list devices: %w", err)
|
||||
}
|
||||
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
for _, dev := range devices {
|
||||
if dev == n.dbusLinkObject {
|
||||
continue
|
||||
}
|
||||
ifaceName := readNetworkManagerDeviceInterface(dev)
|
||||
for _, addr := range readNetworkManagerDeviceDNS(dev) {
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
// IP6Config.Nameservers is a byte slice without zone info;
|
||||
// reattach the device's interface name so a captured fe80::…
|
||||
// stays routable.
|
||||
if addr.IsLinkLocalUnicast() && ifaceName != "" {
|
||||
addr = addr.WithZone(ifaceName)
|
||||
}
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.Value().(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
var devs []dbus.ObjectPath
|
||||
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return devs, nil
|
||||
}
|
||||
|
||||
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
var out []netip.Addr
|
||||
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
|
||||
out = append(out, readIPv4ConfigDNS(path)...)
|
||||
}
|
||||
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
|
||||
out = append(out, readIPv6ConfigDNS(path)...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
|
||||
v, err := obj.GetProperty(property)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
path, ok := v.Value().(dbus.ObjectPath)
|
||||
if !ok || path == "/" {
|
||||
return ""
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
|
||||
// legacy uint32 Nameservers property.
|
||||
if out := readIPv4NameserverData(obj); len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
return readIPv4LegacyNameservers(obj)
|
||||
}
|
||||
|
||||
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
|
||||
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
entries, ok := v.Value().([]map[string]dbus.Variant)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var out []netip.Addr
|
||||
for _, entry := range entries {
|
||||
addrVar, ok := entry["address"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
s, ok := addrVar.Value().(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if a, err := netip.ParseAddr(s); err == nil {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
|
||||
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := v.Value().([]uint32)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]netip.Addr, 0, len(raw))
|
||||
for _, n := range raw {
|
||||
var b [4]byte
|
||||
binary.LittleEndian.PutUint32(b[:], n)
|
||||
out = append(out, netip.AddrFrom4(b))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := v.Value().([][]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]netip.Addr, 0, len(raw))
|
||||
for _, b := range raw {
|
||||
if a, ok := netip.AddrFromSlice(b); ok {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(n.origNameservers)
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager()
|
||||
return newHostManager(s.hostsDNSHolder)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
@@ -31,8 +32,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -101,16 +104,17 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
return &upstreamResolverBase{
|
||||
domain: domain,
|
||||
upstreamServers: srvs,
|
||||
cancel: func() {},
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
ctx: context.Background(),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
{false, "domain1", false},
|
||||
{false, "domain2", false},
|
||||
},
|
||||
},
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
||||
domains := []string{}
|
||||
for _, item := range config.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
domainsUpdate = strings.Join(domains, ",")
|
||||
return nil
|
||||
}
|
||||
|
||||
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||
Domains: []string{"domain1"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
}, nil, 0)
|
||||
|
||||
deactivate(nil)
|
||||
expected := "domain0,domain2"
|
||||
domains := []string{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got := strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||
}
|
||||
|
||||
reactivate()
|
||||
expected = "domain0,domain1,domain2"
|
||||
domains = []string{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got = strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
|
||||
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
|
||||
// + androidHostManager). On non-android the desktop host manager replaces it
|
||||
// during Initialize and the assertion stops making sense. Skipped here until we
|
||||
// have an android CI runner.
|
||||
func skipUnlessAndroid(t *testing.T) {
|
||||
t.Helper()
|
||||
if runtime.GOOS != "android" {
|
||||
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
|
||||
}
|
||||
}
|
||||
|
||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
t.Helper()
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
@@ -1065,7 +1017,6 @@ type mockHandler struct {
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) Stop() {}
|
||||
func (m *mockHandler) ProbeAvailability(context.Context) {}
|
||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||
}
|
||||
|
||||
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
|
||||
// admin-defined nameserver groups targeting the same domain collapse into a
|
||||
// single handler with each group preserved as a sequential inner list.
|
||||
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
||||
wgInterface := &mocWGIface{}
|
||||
service := NewServiceViaMemory(wgInterface)
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: wgInterface,
|
||||
service: service,
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
groups := []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
|
||||
assert.Equal(t, "example.com", muxUpdates[0].domain)
|
||||
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
|
||||
|
||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
|
||||
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
|
||||
assert.Equal(t, upstreamRace{
|
||||
netip.MustParseAddrPort("192.0.2.2:53"),
|
||||
netip.MustParseAddrPort("192.0.2.3:53"),
|
||||
}, handler.upstreamServers[1])
|
||||
}
|
||||
|
||||
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
|
||||
// (overlay route selected-but-no-active-peer) is intentionally NOT an
|
||||
// input to the evaluator anymore: the verdict drives the Enabled flag,
|
||||
// which must always reflect what we actually observed. Gate-aware event
|
||||
// suppression is tested separately in the projection test.
|
||||
//
|
||||
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
|
||||
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
|
||||
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
|
||||
// fresh-working → Unhealthy; otherwise Undecided.
|
||||
func TestEvaluateNSGroupHealth(t *testing.T) {
|
||||
now := time.Now()
|
||||
a := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
|
||||
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
|
||||
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
|
||||
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
|
||||
okThenFail := UpstreamHealth{
|
||||
LastOk: now.Add(-10 * time.Second),
|
||||
LastFail: now.Add(-1 * time.Second),
|
||||
LastErr: "timeout",
|
||||
}
|
||||
failThenOk := UpstreamHealth{
|
||||
LastOk: now.Add(-1 * time.Second),
|
||||
LastFail: now.Add(-10 * time.Second),
|
||||
LastErr: "timeout",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
health map[netip.AddrPort]UpstreamHealth
|
||||
servers []netip.AddrPort
|
||||
wantVerdict nsGroupVerdict
|
||||
wantErrSubst string
|
||||
}{
|
||||
{
|
||||
name: "no record, undecided",
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "fresh success, healthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "fresh failure, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "only stale success, undecided",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "only stale failure, undecided",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "both fresh, fail newer, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "both fresh, ok newer, healthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "two upstreams, one success wins",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: recentFail,
|
||||
b: recentOk,
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "two upstreams, one fail one unseen, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: recentFail,
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "two upstreams, all recent failures, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
|
||||
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "SERVFAIL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
|
||||
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
|
||||
if tc.wantErrSubst != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.wantErrSubst)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||
// without spinning up real handlers.
|
||||
type healthStubHandler struct {
|
||||
health map[netip.AddrPort]UpstreamHealth
|
||||
}
|
||||
|
||||
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (h *healthStubHandler) Stop() {}
|
||||
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
|
||||
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
return h.health
|
||||
}
|
||||
|
||||
// TestProjection_SteadyStateIsSilent guards against duplicate events:
|
||||
// while a group stays Unhealthy tick after tick, only the first
|
||||
// Unhealthy transition may emit. Same for staying Healthy.
|
||||
func TestProjection_SteadyStateIsSilent(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "first fail emits warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.tick()
|
||||
fx.expectNoEvent("staying unhealthy must not re-emit")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "recovery on transition")
|
||||
|
||||
fx.tick()
|
||||
fx.tick()
|
||||
fx.expectNoEvent("staying healthy must not re-emit")
|
||||
}
|
||||
|
||||
// projTestFixture is the common setup for the projection tests: a
|
||||
// single-upstream group whose route classification the test can flip by
|
||||
// assigning to selected/active. Callers drive failures/successes by
|
||||
// mutating stub.health and calling refreshHealth.
|
||||
type projTestFixture struct {
|
||||
t *testing.T
|
||||
recorder *peer.Status
|
||||
events <-chan *proto.SystemEvent
|
||||
server *DefaultServer
|
||||
stub *healthStubHandler
|
||||
group *nbdns.NameServerGroup
|
||||
srv netip.AddrPort
|
||||
selected route.HAMap
|
||||
active route.HAMap
|
||||
}
|
||||
|
||||
func newProjTestFixture(t *testing.T) *projTestFixture {
|
||||
t.Helper()
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
|
||||
srv := netip.MustParseAddrPort("100.64.0.1:53")
|
||||
fx := &projTestFixture{
|
||||
t: t,
|
||||
recorder: recorder,
|
||||
events: sub.Events(),
|
||||
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
|
||||
srv: srv,
|
||||
group: &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
},
|
||||
}
|
||||
fx.server = &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||
activeRoutes: func() route.HAMap { return fx.active },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
}
|
||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
||||
|
||||
fx.server.mux.Lock()
|
||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||
fx.server.mux.Unlock()
|
||||
return fx
|
||||
}
|
||||
|
||||
func (f *projTestFixture) setHealth(h UpstreamHealth) {
|
||||
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
|
||||
}
|
||||
|
||||
func (f *projTestFixture) tick() []peer.NSGroupState {
|
||||
f.server.refreshHealth()
|
||||
return f.recorder.GetDNSStates()
|
||||
}
|
||||
|
||||
func (f *projTestFixture) expectNoEvent(why string) {
|
||||
f.t.Helper()
|
||||
select {
|
||||
case evt := <-f.events:
|
||||
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
|
||||
f.t.Helper()
|
||||
select {
|
||||
case evt := <-f.events:
|
||||
assert.Contains(f.t, evt.Message, substr, why)
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
f.t.Fatalf("expected event (%s) with %q", why, substr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
|
||||
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
|
||||
|
||||
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
|
||||
// that is not inside any selected route (public DNS) fires the warning
|
||||
// on the first Unhealthy tick, no grace period.
|
||||
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled)
|
||||
fx.expectEvent("unreachable", "public DNS failure")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
|
||||
// the upstream is inside a selected route AND the route has a Connected
|
||||
// peer. Tunnel is up, failure is real, emit immediately.
|
||||
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled)
|
||||
fx.expectEvent("unreachable", "overlay + connected failure")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
|
||||
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
|
||||
// First tick: Unhealthy display, no warning. After the grace window
|
||||
// elapses with no recovery, the warning fires.
|
||||
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
|
||||
grace := 50 * time.Millisecond
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = grace
|
||||
fx.selected = overlayMapForTest
|
||||
// active stays nil: routed but not connected.
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
|
||||
fx.expectNoEvent("first fail tick within grace window")
|
||||
|
||||
time.Sleep(grace + 10*time.Millisecond)
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "warning after grace window")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
|
||||
// whose address is inside the WireGuard overlay range but is not
|
||||
// covered by any selected route (peer-to-peer DNS without an explicit
|
||||
// route). Until a peer reports Connected for that address, startup
|
||||
// failures must be held just like the routed case.
|
||||
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
|
||||
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: 50 * time.Millisecond,
|
||||
}
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-sub.Events():
|
||||
t.Fatalf("unexpected event during grace window: %+v", evt)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-sub.Events():
|
||||
assert.Contains(t, evt.Message, "unreachable")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected warning after grace window")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjection_StopClearsHealthState verifies that Stop wipes the
|
||||
// per-group projection state so a subsequent Start doesn't inherit
|
||||
// sticky flags (notably everHealthy) that would bypass the grace
|
||||
// window during the next peer handshake.
|
||||
func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
wgIface := &mocWGIface{}
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: wgIface,
|
||||
service: NewServiceViaMemory(wgIface),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
extraDomains: map[domain.Domain]int{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
currentConfigHash: ^uint64(0),
|
||||
}
|
||||
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
srv := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
server.healthProjectMu.Lock()
|
||||
p, ok := server.nsGroupProj[generateGroupKey(group)]
|
||||
server.healthProjectMu.Unlock()
|
||||
require.True(t, ok, "projection state should exist after tick")
|
||||
require.True(t, p.everHealthy, "tick with success must set everHealthy")
|
||||
|
||||
server.Stop()
|
||||
|
||||
server.healthProjectMu.Lock()
|
||||
cleared := server.nsGroupProj == nil
|
||||
server.healthProjectMu.Unlock()
|
||||
assert.True(t, cleared, "Stop must clear nsGroupProj")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
|
||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||
// comes up and a query succeeds before the grace window elapses. No
|
||||
// warning should ever have fired, and no recovery either.
|
||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||
fx.selected = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectNoEvent("fail within grace, warning suppressed")
|
||||
|
||||
fx.active = overlayMapForTest
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.True(t, states[0].Enabled)
|
||||
fx.expectNoEvent("recovery without prior warning must not emit")
|
||||
}
|
||||
|
||||
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
|
||||
// whole design leans on: recovery events only appear when a warning
|
||||
// event was actually emitted for the current streak. A Healthy verdict
|
||||
// without a prior warning is silent, so the user never sees "recovered"
|
||||
// out of thin air.
|
||||
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.True(t, states[0].Enabled)
|
||||
fx.expectNoEvent("first healthy tick should not recover anything")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "public fail emits immediately")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "recovery follows real warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "second cycle warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "second cycle recovery")
|
||||
}
|
||||
|
||||
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
|
||||
// has ever been Healthy, subsequent failures skip the grace window even
|
||||
// if classification says "routed + not connected". The system has
|
||||
// proved it can work, so any new failure is real.
|
||||
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
|
||||
fx.server.warningDelayBase = time.Hour
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
// Establish "ever healthy".
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectNoEvent("first healthy tick")
|
||||
|
||||
// Peer drops. Query fails. Routed + not connected → normally grace,
|
||||
// but everHealthy flag bypasses it.
|
||||
fx.active = nil
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
|
||||
}
|
||||
|
||||
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
|
||||
// from the design discussion: once a group has been healthy, a brief
|
||||
// reconnect that produces a failing tick will fire warning + recovery.
|
||||
// This is by design: user-visible blips are accurate signal, not noise.
|
||||
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "blip warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "blip recovery")
|
||||
}
|
||||
|
||||
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
|
||||
// rule: a group with at least one public upstream is in the "immediate"
|
||||
// category regardless of the other upstreams' routing, because the
|
||||
// public one has no peer-startup excuse. Prevents public-DNS failures
|
||||
// from being hidden behind a routed sibling.
|
||||
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
events := sub.Events()
|
||||
|
||||
public := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
overlay := netip.MustParseAddrPort("100.64.0.1:53")
|
||||
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: time.Hour,
|
||||
}
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
|
||||
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
|
||||
},
|
||||
}
|
||||
stub := &healthStubHandler{
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
public: {LastFail: time.Now(), LastErr: "servfail"},
|
||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
},
|
||||
}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-events:
|
||||
assert.Contains(t, evt.Message, "unreachable")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected immediate warning because group contains a public upstream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSLoopPrevention(t *testing.T) {
|
||||
wgInterface := &mocWGIface{}
|
||||
service := NewServiceViaMemory(wgInterface)
|
||||
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
|
||||
|
||||
if tt.expectedHandlers > 0 {
|
||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
|
||||
flat := handler.flatUpstreams()
|
||||
assert.Len(t, flat, len(tt.expectedServers))
|
||||
|
||||
if tt.shouldFilterOwnIP {
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
for _, upstream := range flat {
|
||||
assert.NotEqual(t, dnsServerIP, upstream.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
for _, expected := range tt.expectedServers {
|
||||
found := false
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
for _, upstream := range flat {
|
||||
if upstream.Addr() == expected {
|
||||
found = true
|
||||
break
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
@@ -40,10 +41,17 @@ const (
|
||||
)
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
ifaceName string
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
ifaceName string
|
||||
wgIndex int
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
const (
|
||||
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
|
||||
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
|
||||
)
|
||||
|
||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
|
||||
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
|
||||
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
|
||||
|
||||
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
|
||||
|
||||
return &systemdDbusConfigurator{
|
||||
c := &systemdDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
ifaceName: wgInterface,
|
||||
}, nil
|
||||
wgIndex: iface.Index,
|
||||
}
|
||||
|
||||
origNameservers, err := c.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
c.origNameservers = origNameservers
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
|
||||
// every default-route link except our own WG link. Non-default-route links
|
||||
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
|
||||
// actually serve host queries.
|
||||
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list interfaces: %w", err)
|
||||
}
|
||||
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
for _, iface := range ifaces {
|
||||
if !s.isCandidateLink(iface) {
|
||||
continue
|
||||
}
|
||||
linkPath, err := getSystemdLinkPath(iface.Index)
|
||||
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
|
||||
continue
|
||||
}
|
||||
for _, addr := range readSystemdLinkDNS(linkPath) {
|
||||
addr = normalizeSystemdAddr(addr, iface.Name)
|
||||
if !addr.IsValid() {
|
||||
continue
|
||||
}
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
|
||||
if iface.Index == s.wgIndex {
|
||||
return false
|
||||
}
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
|
||||
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
|
||||
// Returns the zero Addr to signal "skip this entry".
|
||||
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
return netip.Addr{}
|
||||
}
|
||||
if addr.IsLinkLocalUnicast() {
|
||||
return addr.WithZone(ifaceName)
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dbus resolve1: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
var p string
|
||||
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dbus.ObjectPath(p), nil
|
||||
}
|
||||
|
||||
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, ok := v.Value().(bool)
|
||||
return ok && b
|
||||
}
|
||||
|
||||
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
entries, ok := v.Value().([][]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var out []netip.Addr
|
||||
for _, entry := range entries {
|
||||
if len(entry) < 2 {
|
||||
continue
|
||||
}
|
||||
raw, ok := entry[1].([]byte)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(raw)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, addr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
||||
|
||||
@@ -1,3 +1,32 @@
|
||||
// Package dns implements the client-side DNS stack: listener/service on the
|
||||
// peer's tunnel address, handler chain that routes questions by domain and
|
||||
// priority, and upstream resolvers that forward what remains to configured
|
||||
// nameservers.
|
||||
//
|
||||
// # Upstream resolution and the race model
|
||||
//
|
||||
// When two or more nameserver groups target the same domain, DefaultServer
|
||||
// merges them into one upstream handler whose state is:
|
||||
//
|
||||
// upstreamResolverBase
|
||||
// └── upstreamServers []upstreamRace // one entry per source NS group
|
||||
// └── []netip.AddrPort // primary, fallback, ...
|
||||
//
|
||||
// Each source nameserver group contributes one upstreamRace. Within a race
|
||||
// upstreams are tried in order: the next is used only on failure (timeout,
|
||||
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
|
||||
// the walk. When more than one race exists, ServeDNS fans out one
|
||||
// goroutine per race and returns the first valid answer, cancelling the
|
||||
// rest. A handler with a single race skips the fan-out.
|
||||
//
|
||||
// # Health projection
|
||||
//
|
||||
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
|
||||
// periodically merges these snapshots across handlers and projects them
|
||||
// into peer.NSGroupState. There is no active probing: a group is marked
|
||||
// unhealthy only when every seen upstream has a recent failure and none
|
||||
// has a recent success. Healthy→unhealthy fires a single
|
||||
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
|
||||
package dns
|
||||
|
||||
import (
|
||||
@@ -11,11 +40,8 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
@@ -25,7 +51,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var currentMTU uint16 = iface.DefaultMTU
|
||||
@@ -67,15 +94,17 @@ const (
|
||||
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
|
||||
ClientTimeout = 5 * time.Second
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
|
||||
// within one race, so a slow primary can't eat the whole race budget.
|
||||
raceMaxTotalTimeout = 5 * time.Second
|
||||
// raceMinPerUpstreamTimeout is the floor applied when dividing
|
||||
// raceMaxTotalTimeout across upstreams within a race.
|
||||
raceMinPerUpstreamTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
@@ -84,6 +113,69 @@ const (
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type UpstreamResolver interface {
|
||||
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
// upstreamRace is an ordered list of upstreams derived from one configured
|
||||
// nameserver group. Order matters: the first upstream is tried first, the
|
||||
// second only on failure, and so on. Multiple upstreamRace values coexist
|
||||
// inside one resolver when overlapping nameserver groups target the same
|
||||
// domain; those races run in parallel and the first valid answer wins.
|
||||
type upstreamRace []netip.AddrPort
|
||||
|
||||
// UpstreamHealth is the last query-path outcome for a single upstream,
|
||||
// consumed by nameserver-group status projection.
|
||||
type UpstreamHealth struct {
|
||||
LastOk time.Time
|
||||
LastFail time.Time
|
||||
LastErr string
|
||||
}
|
||||
|
||||
type upstreamResolverBase struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []upstreamRace
|
||||
domain domain.Domain
|
||||
upstreamTimeout time.Duration
|
||||
|
||||
healthMu sync.RWMutex
|
||||
health map[netip.AddrPort]*UpstreamHealth
|
||||
|
||||
statusRecorder *peer.Status
|
||||
// selectedRoutes returns the current set of client routes the admin
|
||||
// has enabled. Called lazily from the query hot path when an upstream
|
||||
// might need a tunnel-bound client (iOS) and from health projection.
|
||||
selectedRoutes func() route.HAMap
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
type raceResult struct {
|
||||
msg *dns.Msg
|
||||
upstream netip.AddrPort
|
||||
protocol string
|
||||
ede string
|
||||
failures []upstreamFailure
|
||||
}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
@@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
@@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type UpstreamResolver interface {
|
||||
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type upstreamResolverBase struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []netip.AddrPort
|
||||
domain string
|
||||
disabled bool
|
||||
successCount atomic.Int32
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
wg sync.WaitGroup
|
||||
|
||||
deactivate func(error)
|
||||
reactivate func()
|
||||
statusRecorder *peer.Status
|
||||
routeMatch func(netip.Addr) bool
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: domain,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
statusRecorder: statusRecorder,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: d,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the upstream resolver
|
||||
func (u *upstreamResolverBase) String() string {
|
||||
return fmt.Sprintf("Upstream %s", u.upstreamServers)
|
||||
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
// ID returns the unique handler ID. Race groupings and within-race
|
||||
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
|
||||
// the same servers but with different semantics (serial fallback vs
|
||||
// parallel race), so their handlers must not collide.
|
||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||
servers := slices.Clone(u.upstreamServers)
|
||||
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
|
||||
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(u.domain + ":"))
|
||||
for _, s := range servers {
|
||||
hash.Write([]byte(s.String()))
|
||||
hash.Write([]byte("|"))
|
||||
hash.Write([]byte(u.domain.PunycodeString() + ":"))
|
||||
for _, race := range u.upstreamServers {
|
||||
hash.Write([]byte("["))
|
||||
for _, s := range race {
|
||||
hash.Write([]byte(s.String()))
|
||||
hash.Write([]byte("|"))
|
||||
}
|
||||
hash.Write([]byte("]"))
|
||||
}
|
||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
}
|
||||
@@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) Stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
|
||||
u.cancel()
|
||||
}
|
||||
|
||||
u.mutex.Lock()
|
||||
u.wg.Wait()
|
||||
u.mutex.Unlock()
|
||||
// flatUpstreams is for logging and ID hashing only, not for dispatch.
|
||||
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
|
||||
var out []netip.AddrPort
|
||||
for _, g := range u.upstreamServers {
|
||||
out = append(out, g...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// setSelectedRoutes swaps the accessor used to classify overlay-routed
|
||||
// upstreams. Called when route sources are wired after the handler was
|
||||
// built (permanent / iOS constructors).
|
||||
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
|
||||
u.selectedRoutes = selected
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
|
||||
if len(servers) == 0 {
|
||||
return
|
||||
}
|
||||
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
@@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
minPerUpstream := 2 * time.Second
|
||||
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
|
||||
if scaledTimeout > minPerUpstream {
|
||||
timeout = scaledTimeout
|
||||
} else {
|
||||
timeout = minPerUpstream
|
||||
}
|
||||
groups := u.upstreamServers
|
||||
switch len(groups) {
|
||||
case 0:
|
||||
return false, nil
|
||||
case 1:
|
||||
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
|
||||
default:
|
||||
return u.raceAll(ctx, w, r, groups, logger)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
res := u.tryRace(ctx, r, group)
|
||||
if res.msg == nil {
|
||||
return false, res.failures
|
||||
}
|
||||
if res.ede != "" {
|
||||
resutil.SetMeta(w, "ede", res.ede)
|
||||
}
|
||||
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
|
||||
return true, res.failures
|
||||
}
|
||||
|
||||
// raceAll runs one worker per group in parallel, taking the first valid
|
||||
// answer and cancelling the rest.
|
||||
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
raceCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Buffer sized to len(groups) so workers never block on send, even
|
||||
// after the coordinator has returned.
|
||||
results := make(chan raceResult, len(groups))
|
||||
for _, g := range groups {
|
||||
// tryRace clones the request per attempt, so workers never share
|
||||
// a *dns.Msg and concurrent EDNS0 mutations can't race.
|
||||
go func(g upstreamRace) {
|
||||
results <- u.tryRace(raceCtx, r, g)
|
||||
}(g)
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
for range groups {
|
||||
select {
|
||||
case res := <-results:
|
||||
failures = append(failures, res.failures...)
|
||||
if res.msg != nil {
|
||||
if res.ede != "" {
|
||||
resutil.SetMeta(w, "ede", res.ede)
|
||||
}
|
||||
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
|
||||
return true, failures
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return false, failures
|
||||
}
|
||||
}
|
||||
return false, failures
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(group) > 1 {
|
||||
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
|
||||
// still honor raceMinPerUpstreamTimeout as a floor for correctness
|
||||
// on slow links, but the outer context ensures the combined walk
|
||||
// cannot exceed the cap regardless of group size.
|
||||
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range group {
|
||||
if ctx.Err() != nil {
|
||||
return raceResult{failures: failures}
|
||||
}
|
||||
// Clone the request per attempt: the exchange path mutates EDNS0
|
||||
// options in-place, so reusing the same *dns.Msg across sequential
|
||||
// upstreams would carry those mutations (e.g. a reduced UDP size)
|
||||
// into the next attempt.
|
||||
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
|
||||
if failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
continue
|
||||
}
|
||||
res.failures = failures
|
||||
return res
|
||||
}
|
||||
return raceResult{failures: failures}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
|
||||
|
||||
// Advertise EDNS0 so the upstream may include Extended DNS Errors
|
||||
// (RFC 8914) in failure responses; we use those to short-circuit
|
||||
// failover for definitive answers like DNSSEC validation failures.
|
||||
// Operate on a copy so the inbound request is unchanged: a client that
|
||||
// did not advertise EDNS0 must not see an OPT in the response.
|
||||
// The caller already passed a per-attempt copy, so we can mutate r
|
||||
// directly; hadEdns reflects the original client request's state and
|
||||
// controls whether we strip the OPT from the response.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
reqUp := r
|
||||
if !hadEdns {
|
||||
reqUp = r.Copy()
|
||||
reqUp.SetEdns0(upstreamUDPSize(), false)
|
||||
r.SetEdns0(upstreamUDPSize(), false)
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
|
||||
}()
|
||||
startTime := time.Now()
|
||||
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
|
||||
if err != nil {
|
||||
return u.handleUpstreamError(err, upstream, startTime)
|
||||
// A parent cancellation (e.g., another race won and the coordinator
|
||||
// cancelled the losers) is not an upstream failure. Check both the
|
||||
// error chain and the parent context: a transport may surface the
|
||||
// cancellation as a read/deadline error rather than context.Canceled.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
|
||||
}
|
||||
failure := u.handleUpstreamError(err, upstream, startTime)
|
||||
u.markUpstreamFail(upstream, failure.reason)
|
||||
return raceResult{}, failure
|
||||
}
|
||||
|
||||
if rm == nil || !rm.Response {
|
||||
return &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
u.markUpstreamFail(upstream, "no response")
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
proto := ""
|
||||
if upstreamProto != nil {
|
||||
proto = upstreamProto.protocol
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
resutil.SetMeta(w, "ede", edeName(code))
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
reason := dns.RcodeToString[rm.Rcode]
|
||||
u.markUpstreamFail(upstream, reason)
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
}
|
||||
|
||||
// healthEntry returns the mutable health record for addr, lazily creating
|
||||
// the map and the entry. Caller must hold u.healthMu.
|
||||
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
|
||||
if u.health == nil {
|
||||
u.health = make(map[netip.AddrPort]*UpstreamHealth)
|
||||
}
|
||||
h := u.health[addr]
|
||||
if h == nil {
|
||||
h = &UpstreamHealth{}
|
||||
u.health[addr] = h
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
|
||||
u.healthMu.Lock()
|
||||
defer u.healthMu.Unlock()
|
||||
h := u.healthEntry(addr)
|
||||
h.LastOk = time.Now()
|
||||
h.LastFail = time.Time{}
|
||||
h.LastErr = ""
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
|
||||
u.healthMu.Lock()
|
||||
defer u.healthMu.Unlock()
|
||||
h := u.healthEntry(addr)
|
||||
h.LastFail = time.Now()
|
||||
h.LastErr = reason
|
||||
}
|
||||
|
||||
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
|
||||
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
u.healthMu.RLock()
|
||||
defer u.healthMu.RUnlock()
|
||||
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
|
||||
for k, v := range u.health {
|
||||
out[k] = *v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
|
||||
@@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
|
||||
if u.statusRecorder == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
|
||||
if peerInfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
if proto != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", proto)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
@@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
|
||||
|
||||
if err := w.WriteMsg(rm); err != nil {
|
||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
||||
totalUpstreams := len(u.upstreamServers)
|
||||
totalUpstreams := len(u.flatUpstreams())
|
||||
failedCount := len(failures)
|
||||
failureSummary := formatFailures(failures)
|
||||
|
||||
@@ -434,119 +633,6 @@ func edeName(code uint16) string {
|
||||
return fmt.Sprintf("EDE %d", code)
|
||||
}
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
// avoid probe if upstreams could resolve at least one query
|
||||
if u.successCount.Load() > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var success bool
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var errs *multierror.Error
|
||||
for _, upstream := range u.upstreamServers {
|
||||
wg.Add(1)
|
||||
go func(upstream netip.AddrPort) {
|
||||
defer wg.Done()
|
||||
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
errs = multierror.Append(errs, err)
|
||||
mu.Unlock()
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
success = true
|
||||
mu.Unlock()
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// didn't find a working upstream server, let's disable and try later
|
||||
if !success {
|
||||
u.disable(errs.ErrorOrNil())
|
||||
|
||||
if u.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_WARNING,
|
||||
proto.SystemEvent_DNS,
|
||||
"All upstream servers failed (probe failed)",
|
||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||
map[string]string{"upstreams": u.upstreamServersString()},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
|
||||
func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
exponentialBackOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.1,
|
||||
MaxInterval: u.reactivatePeriod,
|
||||
MaxElapsedTime: 0,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
operation := func() error {
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
|
||||
default:
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
|
||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||
} else {
|
||||
// at least one upstream server is available, stop probing
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
|
||||
return fmt.Errorf("upstream check call error")
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||
} else {
|
||||
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.mutex.Lock()
|
||||
u.disabled = false
|
||||
u.mutex.Unlock()
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
//
|
||||
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
|
||||
@@ -558,45 +644,6 @@ func isTimeout(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) disable(err error) {
|
||||
if u.disabled {
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
u.wg.Add(1)
|
||||
go func() {
|
||||
defer u.wg.Done()
|
||||
u.waitUntilResponse()
|
||||
}()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
var servers []string
|
||||
for _, server := range u.upstreamServers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
return strings.Join(servers, ", ")
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
|
||||
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
|
||||
defer cancel()
|
||||
|
||||
if externalCtx != nil {
|
||||
stop2 := context.AfterFunc(externalCtx, cancel)
|
||||
defer stop2()
|
||||
}
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
||||
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
@@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
@@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
t time.Duration
|
||||
err error
|
||||
)
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
rm, t, err := client.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||
}
|
||||
@@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
@@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
|
||||
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
|
||||
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
|
||||
// so net.Dialer doesn't reject the TCP dial with "mismatched local
|
||||
// address type".
|
||||
func toTCPClient(c *dns.Client) *dns.Client {
|
||||
tcp := *c
|
||||
tcp.Net = protoTCP
|
||||
if tcp.Dialer == nil {
|
||||
return &tcp
|
||||
}
|
||||
d := *tcp.Dialer
|
||||
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
|
||||
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
|
||||
}
|
||||
tcp.Dialer = &d
|
||||
return &tcp
|
||||
}
|
||||
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
@@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
|
||||
return bestMatch
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
|
||||
if u.statusRecorder == nil {
|
||||
return ""
|
||||
// haMapRouteCount returns the total number of routes across all HA
|
||||
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
|
||||
// routes per key, so len(hm) is the number of HA groups, not routes.
|
||||
func haMapRouteCount(hm route.HAMap) int {
|
||||
total := 0
|
||||
for _, routes := range hm {
|
||||
total += len(routes)
|
||||
}
|
||||
|
||||
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
|
||||
if peerInfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||
return total
|
||||
}
|
||||
|
||||
// haMapContains checks whether ip is covered by any concrete prefix in
|
||||
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
|
||||
// routes carry a placeholder Network that can't be prefix-checked, so we
|
||||
// can't know at this point whether ip is reached through one. Callers
|
||||
// decide how to interpret the unknown: health projection treats it as
|
||||
// "possibly routed" to avoid emitting false-positive warnings during
|
||||
// startup, while iOS dial selection requires a concrete match before
|
||||
// binding to the tunnel.
|
||||
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
|
||||
for _, routes := range hm {
|
||||
for _, r := range routes {
|
||||
if r.IsDynamic() {
|
||||
haveDynamic = true
|
||||
continue
|
||||
}
|
||||
if r.Network.Contains(ip) {
|
||||
return true, haveDynamic
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, haveDynamic
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
@@ -26,9 +27,9 @@ func newUpstreamResolver(
|
||||
_ WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
c := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
hostsDNSHolder: hostsDNSHolder,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
@@ -24,9 +25,9 @@ func newUpstreamResolver(
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
nonIOS := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
nsNet: wgIface.GetNet(),
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
@@ -27,9 +28,9 @@ func newUpstreamResolver(
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
upstreamIP = upstreamIP.Unmap()
|
||||
}
|
||||
addr := u.wgIface.Address()
|
||||
var routed bool
|
||||
if u.selectedRoutes != nil {
|
||||
// Only a concrete prefix match binds to the tunnel: dialing
|
||||
// through a private client for an upstream we can't prove is
|
||||
// routed would break public resolvers.
|
||||
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
|
||||
}
|
||||
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
||||
addr.IPv6Net.Contains(upstreamIP) ||
|
||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||
routed
|
||||
if needsPrivate {
|
||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
||||
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
}
|
||||
}
|
||||
|
||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||
return ExchangeWithFallback(nil, client, r, upstream)
|
||||
return ExchangeWithFallback(ctx, client, r, upstream)
|
||||
}
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
|
||||
}
|
||||
}
|
||||
resolver.upstreamServers = servers
|
||||
resolver.addRace(servers)
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
cancel()
|
||||
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type mockUpstreamResolver struct {
|
||||
r *dns.Msg
|
||||
rtt time.Duration
|
||||
err error
|
||||
}
|
||||
|
||||
// exchange mock implementation of exchange from upstreamResolver
|
||||
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
return c.r, c.rtt, c.err
|
||||
}
|
||||
|
||||
type mockUpstreamResponse struct {
|
||||
msg *dns.Msg
|
||||
err error
|
||||
msg *dns.Msg
|
||||
err error
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
type mockUpstreamResolverPerServer struct {
|
||||
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
if r, ok := c.responses[upstream]; ok {
|
||||
return r.msg, c.rtt, r.err
|
||||
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
r, ok := c.responses[upstream]
|
||||
if !ok {
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolver{
|
||||
err: dns.ErrTime,
|
||||
r: new(dns.Msg),
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: context.TODO(),
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: time.Microsecond * 100,
|
||||
}
|
||||
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
|
||||
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
|
||||
|
||||
failed := false
|
||||
resolver.deactivate = func(error) {
|
||||
failed = true
|
||||
// After deactivation, make the mock client work again
|
||||
mockClient.err = nil
|
||||
}
|
||||
|
||||
reactivated := false
|
||||
resolver.reactivate = func() {
|
||||
reactivated = true
|
||||
}
|
||||
|
||||
resolver.ProbeAvailability(context.TODO())
|
||||
|
||||
if !failed {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if !resolver.disabled {
|
||||
t.Errorf("resolver should be Disabled")
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
|
||||
if !reactivated {
|
||||
t.Errorf("expected that resolving was reactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.disabled {
|
||||
t.Errorf("should be enabled")
|
||||
if r.delay > 0 {
|
||||
select {
|
||||
case <-time.After(r.delay):
|
||||
case <-ctx.Done():
|
||||
return nil, c.rtt, ctx.Err()
|
||||
}
|
||||
}
|
||||
return r.msg, c.rtt, r.err
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: trackingClient,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamServers: []netip.AddrPort{upstream},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{upstream})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
|
||||
// configured for the same domain, with one broken group. The merge+race
|
||||
// path should answer as fast as the working group and not pay the timeout
|
||||
// of the broken one on every query.
|
||||
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
|
||||
broken := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
working := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
successAnswer := "192.0.2.100"
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
// Force the broken upstream to only unblock via timeout /
|
||||
// cancellation so the assertion below can't pass if races
|
||||
// were run serially.
|
||||
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
|
||||
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: 250 * time.Millisecond,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{broken})
|
||||
resolver.addRace([]netip.AddrPort{working})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
start := time.Now()
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
|
||||
require.NotEmpty(t, responseMSG.Answer)
|
||||
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
|
||||
// Working group answers in a single RTT; the broken group's
|
||||
// timeout (100ms) must not block the response.
|
||||
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
|
||||
// resolver returns SERVFAIL rather than leaking a partial response.
|
||||
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
|
||||
a := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{a})
|
||||
resolver.addRace([]netip.AddrPort{b})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
require.NotNil(t, responseMSG)
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_HealthTracking verifies that query-path results are
|
||||
// recorded into per-upstream health, which is what projects back to
|
||||
// NSGroupState for status reporting.
|
||||
func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
||||
ok := netip.MustParseAddrPort("192.0.2.10:53")
|
||||
bad := netip.MustParseAddrPort("192.0.2.11:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
|
||||
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{ok, bad})
|
||||
|
||||
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
|
||||
health := resolver.UpstreamHealth()
|
||||
require.Contains(t, health, ok)
|
||||
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
|
||||
assert.Empty(t, health[ok].LastErr)
|
||||
|
||||
// bad upstream was never tried because ok answered first; its health
|
||||
// should remain unset.
|
||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
var receivedUDPSize atomic.Uint32
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
receivedUDPSize.Store(uint32(opt.UDPSize()))
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
@@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: tracking,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamServers: []upstreamRace{{upstream1, upstream2}},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
@@ -250,6 +251,8 @@ type Engine struct {
|
||||
jobExecutorWG sync.WaitGroup
|
||||
|
||||
exposeManager *expose.Manager
|
||||
|
||||
sessionWatcher *sessionwatch.Watcher
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -293,6 +296,17 @@ func NewEngine(
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
}
|
||||
// sessionWatcher keeps the SubscribeStatus consumers in sync with the
|
||||
// session expiry deadline. Deadline-change ticks come for free via
|
||||
// Status.SetSessionExpiresAt; the watcher exists to push a wake-up at
|
||||
// T-WarningLead and T-FinalWarningLead so the UI repaints the remaining
|
||||
// time / warning state even when nothing else changed, and to publish
|
||||
// two SystemEvents (the warning composition lives in sessionwatch so
|
||||
// the wire format stays owned by one package):
|
||||
// - T-WarningLead → interactive "Extend now / Dismiss" notification
|
||||
// - T-FinalWarningLead → auto-opened SessionAboutToExpire dialog,
|
||||
// suppressed when the user dismissed the earlier warning
|
||||
engine.sessionWatcher = sessionwatch.New(engine.statusRecorder)
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
@@ -333,6 +347,10 @@ func (e *Engine) Stop() error {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
|
||||
if e.sessionWatcher != nil {
|
||||
e.sessionWatcher.Close()
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.SetDownloadOnly()
|
||||
}
|
||||
@@ -512,16 +530,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
||||
for _, r := range routes {
|
||||
if r.Network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
@@ -874,6 +883,8 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
e.ApplySessionDeadline(update.GetSessionExpiresAt())
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
@@ -1386,9 +1397,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
go e.dnsServer.ProbeAvailability()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1932,7 +1940,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
return dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
|
||||
97
client/internal/engine_authsession.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
// ApplySessionDeadline propagates the absolute SSO session deadline carried on
|
||||
// LoginResponse / SyncResponse to both the watcher (for the edge-triggered
|
||||
// warning) and the status recorder (for the SubscribeStatus / Status RPC
|
||||
// snapshot the UI consumes).
|
||||
//
|
||||
// The wire field is 3-state:
|
||||
// - nil → snapshot carries no info; keep the
|
||||
// previously-anchored deadline (no-op)
|
||||
// - explicit zero (s=0, n=0) → peer is not SSO-registered or expiry is
|
||||
// disabled; clear both sinks
|
||||
// - valid timestamp → new deadline; arm watcher, expose on
|
||||
// status recorder
|
||||
//
|
||||
// Deadline sanity-checks live in sessionwatch.Watcher.Update. Any rejected
|
||||
// value is treated as a clear on both sinks: the alternative — leaving the
|
||||
// previously-known deadline in place — risks the UI confidently displaying
|
||||
// a stale "expires in X" while the server has actually invalidated it.
|
||||
func (e *Engine) ApplySessionDeadline(ts *timestamppb.Timestamp) {
|
||||
if ts == nil {
|
||||
return
|
||||
}
|
||||
var deadline time.Time
|
||||
// Explicit zero (seconds=0 AND nanos=0) is the sentinel for "disabled".
|
||||
// Everything else flows through Watcher.Update, whose sanity-checks
|
||||
// reject out-of-range / pre-epoch / far-future / too-stale values; the
|
||||
// catch-block below converts any rejection into a clear.
|
||||
if ts.GetSeconds() != 0 || ts.GetNanos() != 0 {
|
||||
deadline = ts.AsTime().UTC()
|
||||
}
|
||||
if e.sessionWatcher != nil {
|
||||
if err := e.sessionWatcher.Update(deadline); err != nil {
|
||||
log.Errorf("auth session deadline rejected: %v, clearing", err)
|
||||
deadline = time.Time{}
|
||||
}
|
||||
}
|
||||
if e.statusRecorder != nil {
|
||||
e.statusRecorder.SetSessionExpiresAt(deadline)
|
||||
}
|
||||
}
|
||||
|
||||
// DismissSessionWarning records the user's "Dismiss" click on the
|
||||
// T-WarningLead interactive notification and suppresses the upcoming
|
||||
// T-FinalWarningLead fallback for the current deadline. No-op when the
|
||||
// watcher is not running or holds no deadline.
|
||||
func (e *Engine) DismissSessionWarning() {
|
||||
if e.sessionWatcher == nil {
|
||||
return
|
||||
}
|
||||
e.sessionWatcher.Dismiss()
|
||||
}
|
||||
|
||||
// ExtendAuthSession asks the management server to refresh the SSO session
|
||||
// expiry deadline using the supplied JWT, then mirrors the new deadline into
|
||||
// the daemon's state. The tunnel is untouched; no resync, no reconnect.
|
||||
//
|
||||
// Returns the new absolute UTC deadline (or zero time when the server
|
||||
// reports the peer is not eligible for extension).
|
||||
func (e *Engine) ExtendAuthSession(ctx context.Context, jwtToken string) (time.Time, error) {
|
||||
if jwtToken == "" {
|
||||
return time.Time{}, errors.New("jwt token is required")
|
||||
}
|
||||
if e.mgmClient == nil {
|
||||
return time.Time{}, errors.New("management client is not initialised")
|
||||
}
|
||||
|
||||
info, err := system.GetInfoWithChecks(ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to collect system info for session extend: %v", err)
|
||||
info = system.GetInfo(ctx)
|
||||
}
|
||||
|
||||
resp, err := e.mgmClient.ExtendAuthSession(info, jwtToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("extend auth session on management: %w", err)
|
||||
}
|
||||
|
||||
e.ApplySessionDeadline(resp.GetSessionExpiresAt())
|
||||
|
||||
if resp.GetSessionExpiresAt().IsValid() {
|
||||
return resp.GetSessionExpiresAt().AsTime().UTC(), nil
|
||||
}
|
||||
return time.Time{}, nil
|
||||
}
|
||||
78
client/internal/engine_session_deadline_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
// TestApplySessionDeadline_ThreeState pins down the 3-state semantics of the
|
||||
// wire field carried on LoginResponse / SyncResponse:
|
||||
//
|
||||
// - nil pointer → no info; previously-anchored deadline survives
|
||||
// - explicit zero value → "expiry disabled" sentinel; both sinks cleared
|
||||
// - valid future timestamp → new deadline propagated to both sinks
|
||||
func TestApplySessionDeadline_ThreeState(t *testing.T) {
|
||||
newEngine := func() *Engine {
|
||||
recorder := peer.NewRecorder("")
|
||||
return &Engine{
|
||||
statusRecorder: recorder,
|
||||
sessionWatcher: sessionwatch.New(recorder),
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("valid timestamp sets deadline on both sinks", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
deadline := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
|
||||
e.ApplySessionDeadline(timestamppb.New(deadline))
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(deadline),
|
||||
"status recorder should hold the new deadline")
|
||||
})
|
||||
|
||||
t.Run("nil is a no-op and preserves previous deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
e.ApplySessionDeadline(nil)
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded),
|
||||
"nil snapshot must not disturb the existing deadline")
|
||||
})
|
||||
|
||||
t.Run("explicit zero clears a previously-anchored deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
// Explicit zero Timestamp{} (seconds=0, nanos=0) is the
|
||||
// "expiry disabled / not SSO" sentinel.
|
||||
e.ApplySessionDeadline(×tamppb.Timestamp{})
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().IsZero(),
|
||||
"explicit zero sentinel must clear the deadline")
|
||||
})
|
||||
|
||||
t.Run("invalid timestamp clears the deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
// Out-of-range nanos → IsValid()==false; same-meaning as the
|
||||
// disabled sentinel for downstream sinks.
|
||||
e.ApplySessionDeadline(×tamppb.Timestamp{Seconds: 1, Nanos: -1})
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().IsZero(),
|
||||
"invalid timestamp must clear the deadline")
|
||||
})
|
||||
}
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -900,7 +899,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
}
|
||||
|
||||
// Fallback to deterministic key if no NetBird PSK is configured
|
||||
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
|
||||
determKey, err := conn.rosenpassDetermKey()
|
||||
if err != nil {
|
||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
return nil
|
||||
@@ -909,6 +908,26 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
return determKey
|
||||
}
|
||||
|
||||
// todo: move this logic into Rosenpass package
|
||||
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
|
||||
lk := []byte(conn.config.LocalKey)
|
||||
rk := []byte(conn.config.Key) // remote key
|
||||
var keyInput []byte
|
||||
if string(lk) > string(rk) {
|
||||
//nolint:gocritic
|
||||
keyInput = append(lk[:16], rk[:16]...)
|
||||
} else {
|
||||
//nolint:gocritic
|
||||
keyInput = append(rk[:16], lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
@@ -202,6 +202,12 @@ type Status struct {
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
// sessionExpiresAt is the absolute UTC instant at which the peer's SSO
|
||||
// session expires. Zero when the peer is not SSO-tracked or login
|
||||
// expiration is disabled. Populated from management LoginResponse /
|
||||
// SyncResponse and exposed via the daemon's Status / SubscribeStatus RPC
|
||||
// so the UI can show remaining time without itself talking to mgm.
|
||||
sessionExpiresAt time.Time
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||
lazyConnectionEnabled bool
|
||||
@@ -217,6 +223,14 @@ type Status struct {
|
||||
eventStreams map[string]chan *proto.SystemEvent
|
||||
eventQueue *EventQueue
|
||||
|
||||
// stateChangeStreams fan-out connection-state changes (connected /
|
||||
// disconnected / connecting / address change / peers list change) to
|
||||
// every active SubscribeStatus gRPC stream. Each subscriber gets a
|
||||
// buffered chan; the notifier non-blockingly pings them so a slow
|
||||
// consumer can never stall the daemon.
|
||||
stateChangeMux sync.Mutex
|
||||
stateChangeStreams map[string]chan struct{}
|
||||
|
||||
ingressGwMgr *ingressgw.Manager
|
||||
|
||||
routeIDLookup routeIDLookup
|
||||
@@ -230,6 +244,7 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
stateChangeStreams: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
@@ -360,6 +375,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -385,6 +401,7 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -410,6 +427,7 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -459,6 +477,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -495,6 +514,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -530,6 +550,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -568,6 +589,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -661,6 +683,7 @@ func (d *Status) FinishPeerListModifications() {
|
||||
for _, rd := range dispatches {
|
||||
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
|
||||
@@ -719,6 +742,32 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// SetSessionExpiresAt records the absolute UTC instant at which the peer's
|
||||
// SSO session is set to expire. Pass the zero value to clear (e.g. when the
|
||||
// management server stops publishing a deadline because login expiration was
|
||||
// disabled or the peer is not SSO-tracked). Same-value updates are no-ops;
|
||||
// real changes fan out via notifyStateChange so SubscribeStatus consumers
|
||||
// pick up the new deadline on their next read.
|
||||
func (d *Status) SetSessionExpiresAt(deadline time.Time) {
|
||||
d.mux.Lock()
|
||||
if d.sessionExpiresAt.Equal(deadline) {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.sessionExpiresAt = deadline
|
||||
d.mux.Unlock()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// GetSessionExpiresAt returns the most recently recorded SSO session deadline,
|
||||
// or the zero value when no deadline is tracked.
|
||||
func (d *Status) GetSessionExpiresAt() time.Time {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return d.sessionExpiresAt
|
||||
}
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
@@ -787,6 +836,7 @@ func (d *Status) CleanLocalPeerState() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
@@ -799,6 +849,7 @@ func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkManagementConnected sets ManagementState to connected
|
||||
@@ -811,6 +862,7 @@ func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
@@ -851,6 +903,7 @@ func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
@@ -863,6 +916,7 @@ func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
@@ -1060,16 +1114,19 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
// ClientStart will notify all listeners about the new service state
|
||||
func (d *Status) ClientStart() {
|
||||
d.notifier.clientStart()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// ClientStop will notify all listeners about the new service state
|
||||
func (d *Status) ClientStop() {
|
||||
d.notifier.clientStop()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// ClientTeardown will notify all listeners about the service is under teardown
|
||||
func (d *Status) ClientTeardown() {
|
||||
d.notifier.clientTearDown()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// SetConnectionListener set a listener to the notifier
|
||||
@@ -1211,6 +1268,62 @@ func (d *Status) GetEventHistory() []*proto.SystemEvent {
|
||||
return d.eventQueue.GetAll()
|
||||
}
|
||||
|
||||
// SubscribeToStateChanges hands back a channel that receives a tick on
|
||||
// every connection-state change (connected / disconnected / connecting /
|
||||
// address change / peers-list change). The channel is buffered to one
|
||||
// pending tick so a coalesced burst still wakes the consumer exactly
|
||||
// once. Pass the returned id to UnsubscribeFromStateChanges to detach.
|
||||
func (d *Status) SubscribeToStateChanges() (string, <-chan struct{}) {
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
id := uuid.New().String()
|
||||
ch := make(chan struct{}, 1)
|
||||
d.stateChangeStreams[id] = ch
|
||||
return id, ch
|
||||
}
|
||||
|
||||
// UnsubscribeFromStateChanges releases a SubscribeToStateChanges channel
|
||||
// and closes it so any consumer goroutine selecting on the channel
|
||||
// unblocks cleanly.
|
||||
func (d *Status) UnsubscribeFromStateChanges(id string) {
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
if ch, ok := d.stateChangeStreams[id]; ok {
|
||||
close(ch)
|
||||
delete(d.stateChangeStreams, id)
|
||||
}
|
||||
}
|
||||
|
||||
// notifyStateChange wakes every SubscribeToStateChanges subscriber. Drops
|
||||
// the tick if a subscriber's buffer is full — by definition the consumer
|
||||
// is already going to fetch the latest snapshot, so multiple pending ticks
|
||||
// would be redundant.
|
||||
func (d *Status) notifyStateChange() {
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
for _, ch := range d.stateChangeStreams {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyStateChange is the public wake-the-subscribers entry point used by
|
||||
// callers that mutate state outside the peer recorder — most importantly
|
||||
// the connect-state machine, which writes StatusNeedsLogin into the
|
||||
// shared contextState (client/internal/state.go) without touching any
|
||||
// recorder field. Without this push the SubscribeStatus stream stays on
|
||||
// the previous snapshot until an unrelated peer/management/signal
|
||||
// change happens to fire notifyStateChange, leaving the UI's status
|
||||
// out of sync with the daemon.
|
||||
func (d *Status) NotifyStateChange() {
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -28,15 +28,6 @@ func hashRosenpassKey(key []byte) string {
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
|
||||
// so tests can substitute a mock without spinning up a real UDP server.
|
||||
type rpServer interface {
|
||||
AddPeer(rp.PeerConfig) (rp.PeerID, error)
|
||||
RemovePeer(rp.PeerID) error
|
||||
Run() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
ifaceName string
|
||||
spk []byte
|
||||
@@ -45,7 +36,7 @@ type Manager struct {
|
||||
preSharedKey *[32]byte
|
||||
rpPeerIDs map[string]*rp.PeerID
|
||||
rpWgHandler *NetbirdHandler
|
||||
server rpServer
|
||||
server *rp.Server
|
||||
lock sync.Mutex
|
||||
port int
|
||||
wgIface PresharedKeySetter
|
||||
@@ -60,22 +51,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
||||
|
||||
rpKeyHash := hashRosenpassKey(public)
|
||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
return &Manager{
|
||||
ifaceName: wgIfaceName,
|
||||
rpKeyHash: rpKeyHash,
|
||||
spk: public,
|
||||
ssk: secret,
|
||||
preSharedKey: (*[32]byte)(preSharedKey),
|
||||
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||
// rpWgHandler is created here (instead of only in generateConfig) so it
|
||||
// is never nil between NewManager and Run(). Otherwise an early
|
||||
// OnConnected call (race observed on Android, issue #4341) panics on
|
||||
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
|
||||
// replace it with a fresh handler on each Run() to clear stale peer
|
||||
// state from previous engine sessions.
|
||||
rpWgHandler: NewNetbirdHandler(),
|
||||
lock: sync.Mutex{},
|
||||
}, nil
|
||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetPubKey() []byte {
|
||||
@@ -89,16 +65,6 @@ func (m *Manager) GetAddress() *net.UDPAddr {
|
||||
|
||||
// addPeer adds a new peer to the Rosenpass server
|
||||
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
||||
// Defense in depth against issue #4341 (Android crash): if Run() has not
|
||||
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
|
||||
// error instead of panicking on nil-receiver dereference.
|
||||
if m.server == nil {
|
||||
return fmt.Errorf("rosenpass server not initialized")
|
||||
}
|
||||
if m.rpWgHandler == nil {
|
||||
return fmt.Errorf("rosenpass wg handler not initialized")
|
||||
}
|
||||
|
||||
var err error
|
||||
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
||||
if m.preSharedKey != nil {
|
||||
@@ -113,16 +79,6 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
|
||||
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
||||
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
||||
}
|
||||
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
|
||||
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
|
||||
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
|
||||
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
|
||||
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
|
||||
// IPv6 so its address family matches our listening socket.
|
||||
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
|
||||
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
|
||||
pcfg.Endpoint.IP = v4.To16()
|
||||
}
|
||||
}
|
||||
peerID, err := m.server.AddPeer(pcfg)
|
||||
if err != nil {
|
||||
@@ -226,31 +182,24 @@ func (m *Manager) Run() error {
|
||||
return err
|
||||
}
|
||||
|
||||
server, err := rp.NewUDPServer(conf)
|
||||
m.server, err = rp.NewUDPServer(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
m.server = server
|
||||
m.lock.Unlock()
|
||||
|
||||
log.Infof("starting rosenpass server on port %d", m.port)
|
||||
|
||||
return server.Run()
|
||||
return m.server.Run()
|
||||
}
|
||||
|
||||
// Close closes the Rosenpass server
|
||||
func (m *Manager) Close() error {
|
||||
m.lock.Lock()
|
||||
server := m.server
|
||||
m.server = nil
|
||||
m.lock.Unlock()
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
if err := server.Close(); err != nil {
|
||||
log.Errorf("failed closing local rosenpass server: %v", err)
|
||||
if m.server != nil {
|
||||
err := m.server.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed closing local rosenpass server")
|
||||
}
|
||||
m.server = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,412 +1,14 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
rp "cunicu.li/go-rosenpass"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// --- test doubles -----------------------------------------------------------
|
||||
|
||||
type addPeerCall struct {
|
||||
cfg rp.PeerConfig
|
||||
}
|
||||
|
||||
type removePeerCall struct {
|
||||
id rp.PeerID
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
}
|
||||
|
||||
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
|
||||
if m.addErr != nil {
|
||||
return rp.PeerID{}, m.addErr
|
||||
}
|
||||
// Increment a byte in nextID so distinct peers get distinct IDs.
|
||||
m.nextID[0]++
|
||||
return m.nextID, nil
|
||||
}
|
||||
|
||||
func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.removed = append(m.removed, removePeerCall{id: id})
|
||||
return m.removeErr
|
||||
}
|
||||
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||
|
||||
type setPSKCall struct {
|
||||
peerKey string
|
||||
psk wgtypes.Key
|
||||
updateOnly bool
|
||||
}
|
||||
|
||||
type mockIface struct {
|
||||
mu sync.Mutex
|
||||
calls []setPSKCall
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
|
||||
return m.err
|
||||
}
|
||||
|
||||
// newTestManager builds a Manager with deterministic spk so tie-break
|
||||
// against a peer pubkey is controllable from tests. The provided spk byte
|
||||
// becomes the first byte; remaining bytes are zero.
|
||||
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
|
||||
spk := make([]byte, 32)
|
||||
spk[0] = spkFirstByte
|
||||
return &Manager{
|
||||
ifaceName: "wt0",
|
||||
spk: spk,
|
||||
ssk: make([]byte, 32),
|
||||
rpKeyHash: "test-hash",
|
||||
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||
rpWgHandler: NewNetbirdHandler(),
|
||||
server: mock,
|
||||
}
|
||||
}
|
||||
|
||||
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
|
||||
func validWGKey(t *testing.T, lastByte byte) string {
|
||||
t.Helper()
|
||||
var k wgtypes.Key
|
||||
k[31] = lastByte
|
||||
return k.String()
|
||||
}
|
||||
|
||||
// --- pure helpers ----------------------------------------------------------
|
||||
|
||||
func TestHashRosenpassKey_Deterministic(t *testing.T) {
|
||||
key := []byte("hello-rosenpass")
|
||||
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
|
||||
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
|
||||
}
|
||||
|
||||
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
|
||||
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
|
||||
}
|
||||
|
||||
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
|
||||
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
|
||||
// can only set, not delete, so do it manually with restore via t.Cleanup.
|
||||
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
|
||||
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
|
||||
t.Cleanup(func() {
|
||||
if hadPrev {
|
||||
_ = os.Setenv(defaultLogLevelVar, prev)
|
||||
} else {
|
||||
_ = os.Unsetenv(defaultLogLevelVar)
|
||||
}
|
||||
})
|
||||
require.Equal(t, defaultLog.String(), getLogLevel().String())
|
||||
}
|
||||
|
||||
func TestGetLogLevel_Cases(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"debug": "DEBUG",
|
||||
"info": "INFO",
|
||||
"warn": "WARN",
|
||||
"error": "ERROR",
|
||||
"unknown": "INFO", // default fallback
|
||||
}
|
||||
for input, wantStr := range cases {
|
||||
input, wantStr := input, wantStr
|
||||
t.Run(input, func(t *testing.T) {
|
||||
t.Setenv(defaultLogLevelVar, input)
|
||||
require.Equal(t, wantStr, getLogLevel().String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
||||
port, err := findRandomAvailableUDPPort()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, port, 0)
|
||||
require.LessOrEqual(t, port, 65535)
|
||||
}
|
||||
|
||||
// --- addPeer ---------------------------------------------------------------
|
||||
|
||||
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv) // local spk lexicographically larger
|
||||
|
||||
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
|
||||
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, srv.addCalls, 1)
|
||||
|
||||
ep := srv.addCalls[0].cfg.Endpoint
|
||||
require.NotNil(t, ep, "initiator side must set Endpoint")
|
||||
require.Equal(t, 7000, ep.Port)
|
||||
require.Equal(t, "100.1.1.1", ep.IP.String())
|
||||
}
|
||||
|
||||
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
|
||||
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
|
||||
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.NoError(t, err)
|
||||
|
||||
ep := srv.addCalls[0].cfg.Endpoint
|
||||
require.NotNil(t, ep)
|
||||
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
|
||||
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
|
||||
}
|
||||
|
||||
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0x00, srv) // local spk smaller
|
||||
|
||||
remotePubKey := make([]byte, 32)
|
||||
remotePubKey[0] = 0xFF
|
||||
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
|
||||
}
|
||||
|
||||
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
psk := &wgtypes.Key{0x42}
|
||||
m := newTestManager(0xFF, srv)
|
||||
m.preSharedKey = (*[32]byte)(psk)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
|
||||
}
|
||||
|
||||
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
|
||||
}
|
||||
|
||||
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAddPeer_ServerError_Propagates(t *testing.T) {
|
||||
srv := &mockServer{addErr: errors.New("boom")}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// Regression guard for issue #4341 (Android crash). If Run() has not completed
|
||||
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
|
||||
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
|
||||
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
m.rpWgHandler = nil // simulate Run() not yet completed
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "wg handler not initialized")
|
||||
}
|
||||
|
||||
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
|
||||
m := newTestManager(0xFF, nil)
|
||||
m.server = nil // simulate Run() not yet completed
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "server not initialized")
|
||||
}
|
||||
|
||||
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
|
||||
// issue #4341 cannot occur in the window between NewManager and Run().
|
||||
func TestNewManager_PreInitializesHandler(t *testing.T) {
|
||||
psk := wgtypes.Key{}
|
||||
m, err := NewManager(&psk, "wt0")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
|
||||
}
|
||||
|
||||
func TestAddPeer_RecordsPeerID(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 5)
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
}
|
||||
|
||||
// --- OnConnected / OnDisconnected ------------------------------------------
|
||||
|
||||
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
|
||||
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
|
||||
require.Empty(t, m.rpPeerIDs)
|
||||
}
|
||||
|
||||
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 1)
|
||||
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
|
||||
require.Len(t, srv.addCalls, 1)
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
}
|
||||
|
||||
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
m.OnDisconnected(validWGKey(t, 99))
|
||||
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
|
||||
}
|
||||
|
||||
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 1)
|
||||
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
|
||||
m.OnDisconnected(wgKey)
|
||||
require.Len(t, srv.removed, 1)
|
||||
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
|
||||
}
|
||||
|
||||
// --- IsPresharedKeyInitialized ---------------------------------------------
|
||||
|
||||
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
|
||||
}
|
||||
|
||||
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 2)
|
||||
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||
require.False(t, m.IsPresharedKeyInitialized(wgKey))
|
||||
}
|
||||
|
||||
// --- NetbirdHandler.outputKey ----------------------------------------------
|
||||
|
||||
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x01}
|
||||
wgKey := wgtypes.Key{0xAA}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||
|
||||
psk := rp.Key{0xBB}
|
||||
h.HandshakeCompleted(pid, psk)
|
||||
|
||||
require.Len(t, iface.calls, 1)
|
||||
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
|
||||
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x02}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
|
||||
|
||||
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
|
||||
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
|
||||
|
||||
require.Len(t, iface.calls, 2)
|
||||
require.False(t, iface.calls[0].updateOnly)
|
||||
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
// no SetInterface — iface remains nil
|
||||
pid := rp.PeerID{0x03}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
|
||||
|
||||
// Must not panic.
|
||||
h.HandshakeCompleted(pid, rp.Key{})
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
|
||||
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
|
||||
}
|
||||
|
||||
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x04}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
|
||||
h.HandshakeCompleted(pid, rp.Key{0x01})
|
||||
require.True(t, h.IsPeerInitialized(pid))
|
||||
|
||||
h.RemovePeer(pid)
|
||||
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
|
||||
}
|
||||
|
||||
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
pid := rp.PeerID{0x05}
|
||||
wgKey := wgtypes.Key{0xEE}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface) // set after AddPeer
|
||||
|
||||
h.HandshakeCompleted(pid, rp.Key{0x42})
|
||||
require.Len(t, iface.calls, 1)
|
||||
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||
}
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
|
||||
// of peer public keys. Both peers, given the same key pair, produce the same
|
||||
// output regardless of which side runs the function: the inputs are ordered
|
||||
// lexicographically before concatenation.
|
||||
//
|
||||
// NetBird uses this value as the initial Rosenpass-side preshared key when no
|
||||
// explicit account-level PSK is configured, so both peers converge on the same
|
||||
// PSK before the first post-quantum handshake completes.
|
||||
//
|
||||
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
|
||||
// from public keys and exists only to seed WireGuard until Rosenpass rotates
|
||||
// in a real post-quantum PSK.
|
||||
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
|
||||
lk := []byte(localKey)
|
||||
rk := []byte(remoteKey)
|
||||
if len(lk) < 16 || len(rk) < 16 {
|
||||
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
|
||||
}
|
||||
|
||||
var keyInput []byte
|
||||
if localKey > remoteKey {
|
||||
keyInput = append(keyInput, lk[:16]...)
|
||||
keyInput = append(keyInput, rk[:16]...)
|
||||
} else {
|
||||
keyInput = append(keyInput, rk[:16]...)
|
||||
keyInput = append(keyInput, lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
|
||||
// Peer A and peer B must derive the same PSK regardless of which side
|
||||
// computes it: the function orders inputs internally.
|
||||
a := strings.Repeat("a", 32)
|
||||
b := strings.Repeat("b", 32)
|
||||
|
||||
keyAB, err := DeterministicSeedKey(a, b)
|
||||
require.NoError(t, err)
|
||||
keyBA, err := DeterministicSeedKey(b, a)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
|
||||
}
|
||||
|
||||
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
|
||||
a := strings.Repeat("a", 32)
|
||||
b := strings.Repeat("b", 32)
|
||||
c := strings.Repeat("c", 32)
|
||||
|
||||
keyAB, err := DeterministicSeedKey(a, b)
|
||||
require.NoError(t, err)
|
||||
keyAC, err := DeterministicSeedKey(a, c)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
|
||||
}
|
||||
|
||||
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||
short := "short" // < 16 bytes
|
||||
long := strings.Repeat("x", 32)
|
||||
|
||||
_, err := DeterministicSeedKey(short, long)
|
||||
require.Error(t, err)
|
||||
_, err = DeterministicSeedKey(long, short)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ type Manager interface {
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
GetClientRoutes() route.HAMap
|
||||
GetSelectedClientRoutes() route.HAMap
|
||||
GetActiveClientRoutes() route.HAMap
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
@@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
|
||||
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||
}
|
||||
|
||||
// GetActiveClientRoutes returns the subset of selected client routes
|
||||
// that are currently reachable: the route's peer is Connected and is
|
||||
// the one actively carrying the route (not just an HA sibling).
|
||||
func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
|
||||
m.mux.Lock()
|
||||
selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||
recorder := m.statusRecorder
|
||||
m.mux.Unlock()
|
||||
|
||||
if recorder == nil {
|
||||
return selected
|
||||
}
|
||||
|
||||
out := make(route.HAMap, len(selected))
|
||||
for id, routes := range selected {
|
||||
for _, r := range routes {
|
||||
st, err := recorder.GetPeer(r.Peer)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if st.ConnStatus != peer.StatusConnected {
|
||||
continue
|
||||
}
|
||||
if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
|
||||
continue
|
||||
}
|
||||
out[id] = routes
|
||||
break
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
m.mux.Lock()
|
||||
@@ -704,7 +738,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
||||
}
|
||||
|
||||
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
|
||||
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
|
||||
if len(routes) == 0 {
|
||||
return false
|
||||
}
|
||||
return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {
|
||||
|
||||
@@ -19,6 +19,7 @@ type MockManager struct {
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
GetSelectedClientRoutesFunc func() route.HAMap
|
||||
GetActiveClientRoutesFunc func() route.HAMap
|
||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetActiveClientRoutes() route.HAMap {
|
||||
if m.GetActiveClientRoutesFunc != nil {
|
||||
return m.GetActiveClientRoutesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -12,10 +13,6 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
exitNodeCIDR = "0.0.0.0/0"
|
||||
)
|
||||
|
||||
type RouteSelector struct {
|
||||
mu sync.RWMutex
|
||||
deselectedRoutes map[route.NetID]struct{}
|
||||
@@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
if rs.deselectAll {
|
||||
return false
|
||||
}
|
||||
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
isSelected := !deselected
|
||||
return isSelected
|
||||
return rs.isSelectedLocked(routeID)
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
@@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
|
||||
filtered := route.HAMap{}
|
||||
for id, rt := range routes {
|
||||
netID := id.NetID()
|
||||
_, deselected := rs.deselectedRoutes[netID]
|
||||
if !deselected {
|
||||
if !rs.isDeselectedLocked(id.NetID()) {
|
||||
filtered[id] = rt
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
|
||||
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
|
||||
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
|
||||
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
|
||||
// transparently apply to the synthesized v6 entry.
|
||||
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return selected || deselected
|
||||
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
|
||||
@@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
|
||||
filtered := make(route.HAMap, len(routes))
|
||||
for id, rt := range routes {
|
||||
netID := id.NetID()
|
||||
if rs.isDeselected(netID) {
|
||||
if rs.isDeselectedLocked(netID) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
|
||||
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
|
||||
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
|
||||
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
|
||||
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
|
||||
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
|
||||
name := string(id)
|
||||
if !strings.HasSuffix(name, route.V6ExitSuffix) {
|
||||
return id
|
||||
}
|
||||
if _, ok := rs.selectedRoutes[id]; ok {
|
||||
return id
|
||||
}
|
||||
if _, ok := rs.deselectedRoutes[id]; ok {
|
||||
return id
|
||||
}
|
||||
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return false
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return !deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return true
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[netID]
|
||||
return deselected || rs.deselectAll
|
||||
return deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return selected || deselected
|
||||
}
|
||||
|
||||
func isExitNode(rt []*route.Route) bool {
|
||||
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
|
||||
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) applyExitNodeFilter(
|
||||
@@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter(
|
||||
rt []*route.Route,
|
||||
out route.HAMap,
|
||||
) {
|
||||
|
||||
if rs.hasUserSelections() {
|
||||
// user made explicit selects/deselects
|
||||
if rs.IsSelected(netID) {
|
||||
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
|
||||
// drops the synthesized v6 entry that lacks its own explicit state.
|
||||
effective := rs.effectiveNetID(netID)
|
||||
if rs.hasUserSelectionForRouteLocked(effective) {
|
||||
if rs.isSelectedLocked(effective) {
|
||||
out[id] = rt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
|
||||
// no explicit selection for this route: defer to management's SkipAutoApply flag
|
||||
sel := collectSelected(rt)
|
||||
if len(sel) > 0 {
|
||||
out[id] = sel
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) hasUserSelections() bool {
|
||||
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
|
||||
}
|
||||
|
||||
func collectSelected(rt []*route.Route) []*route.Route {
|
||||
var sel []*route.Route
|
||||
for _, r := range rt {
|
||||
|
||||
@@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
|
||||
assert.Len(t, filtered, 0) // No routes should be selected
|
||||
}
|
||||
|
||||
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
|
||||
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
|
||||
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
|
||||
// v4 base, so legacy persisted selections that predate v6 pairing transparently
|
||||
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
|
||||
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
|
||||
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
|
||||
|
||||
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
|
||||
|
||||
// unrelated v6 with no v4 base touched is unaffected
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
|
||||
})
|
||||
|
||||
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||
|
||||
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
|
||||
// via the mirror; the mirror only applies in exit-node code paths.
|
||||
assert.False(t, rs.IsSelected("corp"))
|
||||
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
|
||||
})
|
||||
|
||||
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
|
||||
|
||||
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||
routes := route.HAMap{
|
||||
"exit1|0.0.0.0/0": {v4Route},
|
||||
"exit1-v6|::/0": {v6Route},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
|
||||
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
|
||||
})
|
||||
|
||||
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
// A route literally named "exit1-something" must not pair-resolve.
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
|
||||
})
|
||||
|
||||
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||
routes := route.HAMap{
|
||||
"exit1|0.0.0.0/0": {v4Route},
|
||||
"exit1-v6|::/0": {v6Route},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
|
||||
})
|
||||
|
||||
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||
|
||||
// A non-default-route entry named "corp-v6" is not an exit node and
|
||||
// must not be skipped because its v4 base "corp" is deselected.
|
||||
corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")}
|
||||
routes := route.HAMap{
|
||||
"corp-v6|10.0.0.0/8": {corpV6},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"),
|
||||
"non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's
|
||||
// SkipAutoApply flag governs each untouched route independently, even when
|
||||
// the user has explicit selections on other routes.
|
||||
func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) {
|
||||
autoApplied := &route.Route{
|
||||
NetID: "Auto",
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
SkipAutoApply: false,
|
||||
}
|
||||
skipApply := &route.Route{
|
||||
NetID: "Skip",
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
SkipAutoApply: true,
|
||||
}
|
||||
routes := route.HAMap{
|
||||
"Auto|0.0.0.0/0": {autoApplied},
|
||||
"Skip|0.0.0.0/0": {skipApply},
|
||||
}
|
||||
|
||||
rs := routeselector.NewRouteSelector()
|
||||
// User makes an unrelated explicit selection elsewhere.
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"}))
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included")
|
||||
assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection")
|
||||
}
|
||||
|
||||
// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized
|
||||
// as exit nodes by the selector's filter path.
|
||||
func TestRouteSelector_V6ExitIsExitNode(t *testing.T) {
|
||||
v6Exit := &route.Route{
|
||||
NetID: "V6Only",
|
||||
Network: netip.MustParsePrefix("::/0"),
|
||||
SkipAutoApply: true,
|
||||
}
|
||||
routes := route.HAMap{
|
||||
"V6Only|::/0": {v6Exit},
|
||||
}
|
||||
|
||||
rs := routeselector.NewRouteSelector()
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply")
|
||||
}
|
||||
|
||||
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
||||
initialRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}
|
||||
|
||||
@@ -33,17 +33,34 @@ func CtxGetState(ctx context.Context) *contextState {
|
||||
}
|
||||
|
||||
type contextState struct {
|
||||
err error
|
||||
status StatusType
|
||||
mutex sync.Mutex
|
||||
err error
|
||||
status StatusType
|
||||
mutex sync.Mutex
|
||||
onChange func()
|
||||
}
|
||||
|
||||
// SetOnChange installs a callback fired after every successful Set. Used by
|
||||
// the daemon to wire the status recorder's notifyStateChange so any
|
||||
// state.Set in the connect/login paths pushes a fresh snapshot to
|
||||
// SubscribeStatus subscribers without each callsite having to opt in.
|
||||
// The callback runs outside the contextState mutex to avoid a lock-order
|
||||
// dependency with the recorder's stateChangeMux.
|
||||
func (c *contextState) SetOnChange(fn func()) {
|
||||
c.mutex.Lock()
|
||||
c.onChange = fn
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *contextState) Set(update StatusType) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.status = update
|
||||
c.err = nil
|
||||
cb := c.onChange
|
||||
c.mutex.Unlock()
|
||||
|
||||
if cb != nil {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *contextState) Status() (StatusType, error) {
|
||||
@@ -57,6 +74,17 @@ func (c *contextState) Status() (StatusType, error) {
|
||||
return c.status, nil
|
||||
}
|
||||
|
||||
// CurrentStatus returns the last status set via Set, ignoring any wrapped
|
||||
// error. Use when the status is needed for reporting purposes (e.g. the
|
||||
// status snapshot stream) and a transient wrapped error from a retry loop
|
||||
// shouldn't blank out the underlying status.
|
||||
func (c *contextState) CurrentStatus() StatusType {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *contextState) Wrap(err error) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
hostDNS := []netip.AddrPort{
|
||||
netip.MustParseAddrPort("9.9.9.9:53"),
|
||||
netip.MustParseAddrPort("149.112.112.112:53"),
|
||||
}
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -32,9 +32,6 @@
|
||||
</File>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
|
||||
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||
<?endif ?>
|
||||
|
||||
<ServiceInstall
|
||||
Id="NetBirdService"
|
||||
@@ -62,8 +59,23 @@
|
||||
<Component Id="NetbirdAumidRegistry" Guid="*">
|
||||
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
|
||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
||||
<!-- Pre-seed the CLSID the Wails notifications service reads on
|
||||
first startup (notifications_windows.go:getGUID looks for
|
||||
the CustomActivator value under this key). Without this
|
||||
the service generates a fresh per-install UUID, which
|
||||
diverges from the ToastActivatorCLSID set on the Start
|
||||
Menu / Desktop shortcuts above and the COM activator
|
||||
never fires when a toast is clicked. -->
|
||||
<RegistryValue Name="CustomActivator" Type="string" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
|
||||
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
|
||||
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
|
||||
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
|
||||
<RemoveRegistryValue Root="HKCU"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
|
||||
</Component>
|
||||
</StandardDirectory>
|
||||
|
||||
<StandardDirectory Id="CommonAppDataFolder">
|
||||
@@ -76,19 +88,67 @@
|
||||
</Directory>
|
||||
</StandardDirectory>
|
||||
|
||||
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
|
||||
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
|
||||
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
|
||||
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
|
||||
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
|
||||
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
|
||||
<RemoveRegistryValue Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
|
||||
</Component>
|
||||
|
||||
<ComponentGroup Id="NetbirdFilesComponent">
|
||||
<ComponentRef Id="NetbirdFiles" />
|
||||
<ComponentRef Id="NetbirdAumidRegistry" />
|
||||
<ComponentRef Id="NetbirdAutoStart" />
|
||||
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
|
||||
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
|
||||
</ComponentGroup>
|
||||
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" TerminateProcess="0" />
|
||||
|
||||
<!-- WebView2 evergreen runtime detection.
|
||||
Probe both the per-machine and per-user EdgeUpdate keys; if either
|
||||
reports a non-empty `pv` value the runtime is already installed
|
||||
and we skip the bootstrapper. -->
|
||||
<Property Id="WEBVIEW2_VERSION_HKLM">
|
||||
<RegistrySearch Id="WV2HKLM" Root="HKLM"
|
||||
Key="SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
|
||||
Name="pv" Type="raw" Bitness="always64" />
|
||||
</Property>
|
||||
<Property Id="WEBVIEW2_VERSION_HKCU">
|
||||
<RegistrySearch Id="WV2HKCU" Root="HKCU"
|
||||
Key="Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
|
||||
Name="pv" Type="raw" />
|
||||
</Property>
|
||||
|
||||
<!-- Embed the bootstrapper payload. Path is relative to the WiX
|
||||
working directory; sign-pipelines stages it next to client/
|
||||
via `wails3 generate webview2bootstrapper`. -->
|
||||
<Binary Id="WebView2Bootstrapper" SourceFile=".\client\MicrosoftEdgeWebview2Setup.exe" />
|
||||
|
||||
<CustomAction Id="InstallWebView2"
|
||||
BinaryRef="WebView2Bootstrapper"
|
||||
ExeCommand="/silent /install"
|
||||
Execute="deferred"
|
||||
Impersonate="no"
|
||||
Return="check" />
|
||||
|
||||
<InstallExecuteSequence>
|
||||
<Custom Action="InstallWebView2" Before="InstallFinalize"
|
||||
Condition="NOT WEBVIEW2_VERSION_HKLM AND NOT WEBVIEW2_VERSION_HKCU AND NOT REMOVE" />
|
||||
</InstallExecuteSequence>
|
||||
|
||||
<!-- Icons -->
|
||||
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />
|
||||
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\build\windows\icon.ico" />
|
||||
<Property Id="ARPPRODUCTICON" Value="NetbirdIcon" />
|
||||
|
||||
</Package>
|
||||
|
||||
@@ -24,6 +24,12 @@ service DaemonService {
|
||||
// Status of the service.
|
||||
rpc Status(StatusRequest) returns (StatusResponse) {}
|
||||
|
||||
// SubscribeStatus pushes a fresh StatusResponse on connection state
|
||||
// changes (Connected / Disconnected / Connecting / address change /
|
||||
// peers list change). The first message on the stream is the current
|
||||
// snapshot, so a freshly-subscribed UI doesn't need to also call Status.
|
||||
rpc SubscribeStatus(StatusRequest) returns (stream StatusResponse) {}
|
||||
|
||||
// Down stops engine work in the daemon.
|
||||
rpc Down(DownRequest) returns (DownResponse) {}
|
||||
|
||||
@@ -109,6 +115,25 @@ service DaemonService {
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||
|
||||
// RequestExtendAuthSession initiates an SSO session-extension flow.
|
||||
// The daemon prepares a PKCE/device-code request against the IdP and
|
||||
// returns the verification URI; the UI is expected to open it. The flow
|
||||
// state is kept in the daemon until WaitExtendAuthSession completes it.
|
||||
rpc RequestExtendAuthSession(RequestExtendAuthSessionRequest) returns (RequestExtendAuthSessionResponse) {}
|
||||
|
||||
// WaitExtendAuthSession blocks until the user finishes the SSO step
|
||||
// started by RequestExtendAuthSession, then forwards the resulting JWT
|
||||
// to the management server's ExtendAuthSession RPC. Returns the new
|
||||
// session expiry deadline. The tunnel stays up the entire time.
|
||||
rpc WaitExtendAuthSession(WaitExtendAuthSessionRequest) returns (WaitExtendAuthSessionResponse) {}
|
||||
|
||||
// DismissSessionWarning records that the user clicked "Dismiss" on the
|
||||
// T-WarningLead interactive notification, suppressing the auto-opened
|
||||
// SessionAboutToExpire dialog that would otherwise fire at
|
||||
// T-FinalWarningLead for the current deadline. Idempotent and best-effort:
|
||||
// a missed call only means the fallback dialog will still appear.
|
||||
rpc DismissSessionWarning(DismissSessionWarningRequest) returns (DismissSessionWarningResponse) {}
|
||||
|
||||
// StartCPUProfile starts CPU profiling in the daemon
|
||||
rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {}
|
||||
|
||||
@@ -227,6 +252,12 @@ message UpRequest {
|
||||
optional string profileName = 1;
|
||||
optional string username = 2;
|
||||
reserved 3;
|
||||
// async instructs the daemon to start the connection attempt and return
|
||||
// immediately without waiting for the engine to become ready. Status updates
|
||||
// are delivered via the SubscribeStatus stream. When false (the default) the
|
||||
// RPC blocks until the engine is running or gives up, which is the behaviour
|
||||
// needed by the CLI.
|
||||
bool async = 4;
|
||||
}
|
||||
|
||||
message UpResponse {}
|
||||
@@ -244,6 +275,10 @@ message StatusResponse{
|
||||
FullStatus fullStatus = 2;
|
||||
// NetBird daemon version
|
||||
string daemonVersion = 3;
|
||||
// Absolute UTC instant at which the peer's SSO session expires.
|
||||
// Unset when the peer is not SSO-registered or login expiration is disabled.
|
||||
// The UI derives "warning active" from this value and its own clock.
|
||||
google.protobuf.Timestamp sessionExpiresAt = 4;
|
||||
}
|
||||
|
||||
message DownRequest {}
|
||||
@@ -798,6 +833,55 @@ message WaitJWTTokenResponse {
|
||||
int64 expiresIn = 3;
|
||||
}
|
||||
|
||||
// RequestExtendAuthSessionRequest kicks off the session-extension SSO flow.
|
||||
message RequestExtendAuthSessionRequest {
|
||||
// Optional OIDC login_hint (typically the user's email) to pre-fill the
|
||||
// IdP login form.
|
||||
optional string hint = 1;
|
||||
}
|
||||
|
||||
// RequestExtendAuthSessionResponse carries the verification URI the UI
|
||||
// should open in a browser. The daemon retains the flow state and resolves
|
||||
// it via WaitExtendAuthSession.
|
||||
message RequestExtendAuthSessionResponse {
|
||||
// verification URI for the user to open in the browser
|
||||
string verificationURI = 1;
|
||||
// complete verification URI (with embedded user code)
|
||||
string verificationURIComplete = 2;
|
||||
// user code to enter on verification URI (for device-code flows)
|
||||
string userCode = 3;
|
||||
// device code for matching the WaitExtendAuthSession call to this flow
|
||||
string deviceCode = 4;
|
||||
// expiration time in seconds for the device code / PKCE flow
|
||||
int64 expiresIn = 5;
|
||||
}
|
||||
|
||||
// WaitExtendAuthSessionRequest is sent by the UI after it opens the
|
||||
// verification URI. The daemon blocks on this call until the user
|
||||
// completes (or aborts) the SSO step.
|
||||
message WaitExtendAuthSessionRequest {
|
||||
// device code returned by RequestExtendAuthSession
|
||||
string deviceCode = 1;
|
||||
// user code for verification
|
||||
string userCode = 2;
|
||||
}
|
||||
|
||||
// WaitExtendAuthSessionResponse carries the refreshed deadline returned
|
||||
// by the management server. Unset when the management server reports the
|
||||
// peer is not eligible for session extension.
|
||||
message WaitExtendAuthSessionResponse {
|
||||
google.protobuf.Timestamp sessionExpiresAt = 1;
|
||||
}
|
||||
|
||||
// DismissSessionWarningRequest is sent by the UI when the user clicks
|
||||
// "Dismiss" on the T-WarningLead notification.
|
||||
message DismissSessionWarningRequest {}
|
||||
|
||||
// DismissSessionWarningResponse acknowledges the dismissal. Carries no
|
||||
// payload — the daemon's only obligation is to silence the upcoming
|
||||
// T-FinalWarningLead fallback for the current deadline.
|
||||
message DismissSessionWarningResponse {}
|
||||
|
||||
// StartCPUProfileRequest for starting CPU profiling
|
||||
message StartCPUProfileRequest {}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
@@ -67,6 +68,12 @@ type Server struct {
|
||||
logFile string
|
||||
|
||||
oauthAuthFlow oauthAuthFlow
|
||||
// extendAuthSessionFlow holds the pending PKCE flow created by
|
||||
// RequestExtendAuthSession until WaitExtendAuthSession resolves it.
|
||||
// Kept separate from oauthAuthFlow (which is reserved for the SSH
|
||||
// JWT path) so a concurrent SSH auth doesn't clobber the session
|
||||
// extend flow or vice versa.
|
||||
extendAuthSessionFlow *auth.PendingFlow
|
||||
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
@@ -123,6 +130,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
captureEnabled: captureEnabled,
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
extendAuthSessionFlow: auth.NewPendingFlow(),
|
||||
}
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
@@ -140,6 +148,15 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
// Every contextState.Set in the connect/login/server paths must push a
|
||||
// SubscribeStatus snapshot, otherwise transitions that don't happen to
|
||||
// be accompanied by a Mark{Management,Signal,...} call (e.g. plain
|
||||
// StatusNeedsLogin after a PermissionDenied login, StatusLoginFailed
|
||||
// after OAuth init failure, StatusIdle in the Login defer) leave the
|
||||
// UI stuck on the previous status until the next unrelated peer event.
|
||||
// Binding the recorder here means new state.Set callsites don't have
|
||||
// to opt in individually.
|
||||
state.SetOnChange(s.statusRecorder.NotifyStateChange)
|
||||
|
||||
if err := handlePanicLog(); err != nil {
|
||||
log.Warnf("failed to redirect stderr: %v", err)
|
||||
@@ -220,10 +237,20 @@ func (s *Server) Start() error {
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
// close(giveUpChan) MUST run on every exit path (DisableAutoConnect
|
||||
// return, backoff.Retry return, panic) — Down() blocks for up to 5s
|
||||
// waiting on this signal before flipping the state to Idle, and a
|
||||
// missed close leaves Down() always hitting the timeout. The signal
|
||||
// fires AFTER clientRunning=false is committed under the mutex so a
|
||||
// Down/Up racing with the goroutine exit never observes a half-state
|
||||
// (chan closed but clientRunning still true).
|
||||
defer func() {
|
||||
s.mutex.Lock()
|
||||
s.clientRunning = false
|
||||
s.mutex.Unlock()
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
@@ -258,6 +285,15 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
runOperation := func() error {
|
||||
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
|
||||
if err != nil {
|
||||
// PermissionDenied means the daemon transitioned to NeedsLogin
|
||||
// inside connect(). Without backoff.Permanent the outer retry
|
||||
// re-enters connect(), which resets the state to Connecting and
|
||||
// makes the tray flicker between NeedsLogin and Connecting until
|
||||
// the user logs in. Stop retrying and let the state stick.
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
log.Debugf("run client connection exited with PermissionDenied, waiting for login")
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
||||
return err
|
||||
}
|
||||
@@ -269,10 +305,6 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
@@ -341,9 +373,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
}
|
||||
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
if *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
}
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
}
|
||||
|
||||
if msg.CleanDNSLabels {
|
||||
@@ -569,8 +599,35 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
return &proto.LoginResponse{}, nil
|
||||
}
|
||||
|
||||
// WaitSSOLogin uses the userCode to validate the TokenInfo and
|
||||
// waits for the user to continue with the login on a browser
|
||||
// WaitSSOLogin validates the supplied userCode against the in-flight OAuth
|
||||
// device/PKCE flow and blocks until the user finishes the browser leg.
|
||||
//
|
||||
// State transitions on exit:
|
||||
//
|
||||
// ┌──────────────────────────────────────────┬──────────────────────────────────┐
|
||||
// │ Outcome │ contextState │
|
||||
// ├──────────────────────────────────────────┼──────────────────────────────────┤
|
||||
// │ Success → loginAttempt → Connected │ StatusConnected (loginAttempt) │
|
||||
// │ Success → loginAttempt → still-NeedsLogin│ StatusNeedsLogin (loginAttempt) │
|
||||
// │ Success → loginAttempt error │ StatusLoginFailed (loginAttempt) │
|
||||
// │ UserCode mismatch │ StatusLoginFailed │
|
||||
// │ WaitToken: context.Canceled (external │ defer runs: status untouched if │
|
||||
// │ abort — profile switch invokes │ already NeedsLogin/LoginFailed,│
|
||||
// │ actCancel/waitCancel, app quit, │ else StatusIdle. Keeps the │
|
||||
// │ another WaitSSOLogin started) │ cancel from leaking as a │
|
||||
// │ │ spurious LoginFailed on the │
|
||||
// │ │ next profile's Up. │
|
||||
// │ WaitToken: context.DeadlineExceeded │ StatusNeedsLogin │
|
||||
// │ (OAuth device-code window expired │ (retryable; the UI's "Connect" │
|
||||
// │ while waiting on the browser leg) │ re-enters the Login flow) │
|
||||
// │ WaitToken: any other error │ StatusLoginFailed │
|
||||
// │ (access_denied, expired_token, HTTP │ (genuine auth/IO failure; │
|
||||
// │ failure, token validation rejection) │ surfaced verbatim to caller) │
|
||||
// └──────────────────────────────────────────┴──────────────────────────────────┘
|
||||
//
|
||||
// The defer at the top of the function applies the Idle fallback so callers
|
||||
// that bypass the explicit Set calls (the Canceled branch above, the success
|
||||
// path before loginAttempt) still land on a sensible terminal status.
|
||||
func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLoginRequest) (*proto.WaitSSOLoginResponse, error) {
|
||||
s.mutex.Lock()
|
||||
if s.actCancel != nil {
|
||||
@@ -630,7 +687,21 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.expiresAt = time.Now()
|
||||
s.mutex.Unlock()
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
// External abort (profile switch, app quit, another
|
||||
// WaitSSOLogin started). Not a login failure — let the
|
||||
// top-level defer fall through to StatusIdle so the next
|
||||
// flow starts from a clean state.
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
// OAuth device-code window expired with no user action.
|
||||
// Retryable — leave the daemon in NeedsLogin so the UI
|
||||
// keeps the Login affordance instead of reading as a
|
||||
// hard failure.
|
||||
state.Set(internal.StatusNeedsLogin)
|
||||
default:
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
}
|
||||
log.Errorf("waiting for browser login failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
@@ -745,6 +816,9 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
|
||||
s.mutex.Unlock()
|
||||
if msg.GetAsync() {
|
||||
return &proto.UpResponse{}, nil
|
||||
}
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
|
||||
@@ -844,23 +918,37 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
|
||||
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
|
||||
// The giveUpChan is closed at the end of connectWithRetryRuns.
|
||||
// The giveUpChan is closed by the goroutine's deferred cleanup (see
|
||||
// connectWithRetryRuns) on every exit path. A timeout here typically
|
||||
// means the goroutine is still wedged inside a slow teardown step.
|
||||
if giveUpChan != nil {
|
||||
select {
|
||||
case <-giveUpChan:
|
||||
log.Debugf("client goroutine finished successfully")
|
||||
log.Debugf("client goroutine finished, giveUpChan closed")
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
|
||||
}
|
||||
}
|
||||
|
||||
// Set Idle only after the retry goroutine has exited (or timed out).
|
||||
// Setting it earlier races with the goroutine's own Set(StatusConnecting)
|
||||
// at the top of each retry attempt, which would leave the snapshot
|
||||
// stuck at Connecting long after the user asked to disconnect.
|
||||
internal.CtxGetState(s.rootCtx).Set(internal.StatusIdle)
|
||||
|
||||
// Clear stale management/signal errors so the next Up() (typically for a
|
||||
// different profile) starts with a clean status snapshot. Without this,
|
||||
// a managementError left over from a LoginFailed cycle persists in the
|
||||
// statusRecorder and appears in the new profile's initial
|
||||
// SubscribeStatus snapshot, making the new profile look like it also
|
||||
// failed to log in.
|
||||
s.statusRecorder.MarkManagementDisconnected(nil)
|
||||
s.statusRecorder.MarkSignalDisconnected(nil)
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -1114,9 +1202,23 @@ func (s *Server) Status(
|
||||
}
|
||||
}
|
||||
|
||||
status, err := internal.CtxGetState(s.rootCtx).Status()
|
||||
return s.buildStatusResponse(msg)
|
||||
}
|
||||
|
||||
// buildStatusResponse composes a StatusResponse from the current daemon
|
||||
// state. Shared between the unary Status RPC and the SubscribeStatus
|
||||
// stream so both paths return identical snapshots.
|
||||
func (s *Server) buildStatusResponse(msg *proto.StatusRequest) (*proto.StatusResponse, error) {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// state.Status() blanks the status when err is set (e.g. management
|
||||
// retry loop wrapped a connection error). The underlying status is
|
||||
// still meaningful and the failure is already surfaced via
|
||||
// FullStatus.ManagementState.Error, so don't propagate err — that
|
||||
// would tear down the SubscribeStatus stream and cause the UI to
|
||||
// mark the daemon as unreachable on every retry.
|
||||
status = state.CurrentStatus()
|
||||
}
|
||||
|
||||
if status == internal.StatusNeedsLogin && s.isSessionActive.Load() {
|
||||
@@ -1127,6 +1229,10 @@ func (s *Server) Status(
|
||||
|
||||
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
|
||||
|
||||
if deadline := s.statusRecorder.GetSessionExpiresAt(); !deadline.IsZero() {
|
||||
statusResponse.SessionExpiresAt = timestamppb.New(deadline)
|
||||
}
|
||||
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
@@ -1356,6 +1462,131 @@ func (s *Server) WaitJWTToken(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequestExtendAuthSession initiates the SSO session-extension flow and
|
||||
// returns the verification URI the UI should open. The flow state is held
|
||||
// in s.extendAuthSessionFlow until WaitExtendAuthSession resolves it.
|
||||
func (s *Server) RequestExtendAuthSession(
|
||||
ctx context.Context,
|
||||
msg *proto.RequestExtendAuthSessionRequest,
|
||||
) (*proto.RequestExtendAuthSessionResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
||||
}
|
||||
if connectClient == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not running")
|
||||
}
|
||||
|
||||
hint := ""
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
if hint == "" {
|
||||
hint = profilemanager.GetLoginHint()
|
||||
}
|
||||
|
||||
isDesktop := isUnixRunningDesktop()
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
s.extendAuthSessionFlow.Set(oAuthFlow, authInfo)
|
||||
|
||||
return &proto.RequestExtendAuthSessionResponse{
|
||||
VerificationURI: authInfo.VerificationURI,
|
||||
VerificationURIComplete: authInfo.VerificationURIComplete,
|
||||
UserCode: authInfo.UserCode,
|
||||
DeviceCode: authInfo.DeviceCode,
|
||||
ExpiresIn: int64(authInfo.ExpiresIn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WaitExtendAuthSession blocks until the user completes the SSO step
|
||||
// initiated by RequestExtendAuthSession, then forwards the resulting JWT
|
||||
// to the management server's ExtendAuthSession RPC. The returned deadline
|
||||
// is also applied locally via the engine so SubscribeStatus consumers see
|
||||
// the refreshed state.
|
||||
func (s *Server) WaitExtendAuthSession(
|
||||
ctx context.Context,
|
||||
req *proto.WaitExtendAuthSessionRequest,
|
||||
) (*proto.WaitExtendAuthSessionResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
oAuthFlow, authInfo, ok := s.extendAuthSessionFlow.Get()
|
||||
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if !ok || authInfo.DeviceCode != req.DeviceCode {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active extend-session flow")
|
||||
}
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to obtain JWT token: %v", err)
|
||||
}
|
||||
|
||||
// Clear pending flow before talking to mgm so a retry can re-initiate.
|
||||
s.extendAuthSessionFlow.Clear()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not running")
|
||||
}
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "engine is not initialised")
|
||||
}
|
||||
|
||||
deadline, err := engine.ExtendAuthSession(ctx, tokenInfo.GetTokenToUse())
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "management ExtendAuthSession failed: %v", err)
|
||||
}
|
||||
|
||||
resp := &proto.WaitExtendAuthSessionResponse{}
|
||||
if !deadline.IsZero() {
|
||||
resp.SessionExpiresAt = timestamppb.New(deadline)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// DismissSessionWarning forwards the user's "Dismiss" click on the
|
||||
// T-WarningLead notification down to the engine's sessionWatcher so the
|
||||
// T-FinalWarningLead fallback is suppressed for the current deadline.
|
||||
// Best-effort: when the client/engine is not yet running the call is a
|
||||
// successful no-op (the watcher has no deadline to dismiss anyway).
|
||||
func (s *Server) DismissSessionWarning(
|
||||
_ context.Context,
|
||||
_ *proto.DismissSessionWarningRequest,
|
||||
) (*proto.DismissSessionWarningResponse, error) {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
if connectClient == nil {
|
||||
return &proto.DismissSessionWarningResponse{}, nil
|
||||
}
|
||||
if engine := connectClient.Engine(); engine != nil {
|
||||
engine.DismissSessionWarning()
|
||||
}
|
||||
return &proto.DismissSessionWarningResponse{}, nil
|
||||
}
|
||||
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy.
|
||||
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
|
||||
s.mutex.Lock()
|
||||
|
||||
57
client/server/status_stream.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// SubscribeStatus pushes a fresh StatusResponse on every connection state
|
||||
// change. The first message is the current snapshot, so a re-subscribing
|
||||
// client doesn't need to also call Status. Subsequent messages fire when
|
||||
// the peer recorder reports any of: connected/disconnected/connecting,
|
||||
// management or signal flip, address change, or peers list change.
|
||||
//
|
||||
// The change channel coalesces bursts to a single tick. If the consumer
|
||||
// is slow the daemon drops extras (not blocks), and the next snapshot
|
||||
// the consumer pulls already reflects everything.
|
||||
func (s *Server) SubscribeStatus(req *proto.StatusRequest, stream proto.DaemonService_SubscribeStatusServer) error {
|
||||
subID, ch := s.statusRecorder.SubscribeToStateChanges()
|
||||
defer func() {
|
||||
s.statusRecorder.UnsubscribeFromStateChanges(subID)
|
||||
log.Debug("client unsubscribed from status updates")
|
||||
}()
|
||||
|
||||
log.Debug("client subscribed to status updates")
|
||||
|
||||
if err := s.sendStatusSnapshot(req, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := s.sendStatusSnapshot(req, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-stream.Context().Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sendStatusSnapshot(req *proto.StatusRequest, stream proto.DaemonService_SubscribeStatusServer) error {
|
||||
resp, err := s.buildStatusResponse(req)
|
||||
if err != nil {
|
||||
log.Warnf("build status snapshot for stream: %v", err)
|
||||
return err
|
||||
}
|
||||
if err := stream.Send(resp); err != nil {
|
||||
log.Warnf("send status snapshot to stream: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -55,6 +55,10 @@ type ConvertOptions struct {
|
||||
IPsFilter map[string]struct{}
|
||||
ConnectionTypeFilter string
|
||||
ProfileName string
|
||||
// SessionExpiresAt is the absolute UTC instant at which the peer's SSO
|
||||
// session expires. Zero when the peer is not SSO-tracked or login
|
||||
// expiration is disabled. Sourced from StatusResponse.SessionExpiresAt.
|
||||
SessionExpiresAt time.Time
|
||||
}
|
||||
|
||||
type PeerStateDetailOutput struct {
|
||||
@@ -153,6 +157,11 @@ type OutputOverview struct {
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
// SessionExpiresAt is the absolute UTC instant at which the peer's SSO
|
||||
// session expires. nil when the peer is not SSO-tracked or login
|
||||
// expiration is disabled. Pointer (rather than zero-value time.Time) so
|
||||
// JSON / YAML omit the field entirely with `,omitempty`.
|
||||
SessionExpiresAt *time.Time `json:"sessionExpiresAt,omitempty" yaml:"sessionExpiresAt,omitempty"`
|
||||
}
|
||||
|
||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||
@@ -198,6 +207,10 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
ProfileName: opts.ProfileName,
|
||||
SSHServerState: sshServerOverview,
|
||||
}
|
||||
if !opts.SessionExpiresAt.IsZero() {
|
||||
t := opts.SessionExpiresAt
|
||||
overview.SessionExpiresAt = &t
|
||||
}
|
||||
|
||||
if opts.Anonymize {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
@@ -535,6 +548,15 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
var sessionExpiryString string
|
||||
if o.SessionExpiresAt != nil && !o.SessionExpiresAt.IsZero() {
|
||||
sessionExpiryString = fmt.Sprintf(
|
||||
"Session expires: %s (in %s)\n",
|
||||
o.SessionExpiresAt.Format(time.RFC3339),
|
||||
FormatRemainingDuration(time.Until(*o.SessionExpiresAt)),
|
||||
)
|
||||
}
|
||||
|
||||
var forwardingRulesString string
|
||||
if o.NumberOfForwardingRules > 0 {
|
||||
forwardingRulesString = fmt.Sprintf("Forwarding rules: %d\n", o.NumberOfForwardingRules)
|
||||
@@ -565,6 +587,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"SSH Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"%s"+
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
o.DaemonVersion,
|
||||
@@ -583,6 +606,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
sshServerStatus,
|
||||
networks,
|
||||
forwardingRulesString,
|
||||
sessionExpiryString,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
@@ -996,3 +1020,57 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
|
||||
}
|
||||
}
|
||||
|
||||
// FormatRemainingDuration renders a time.Duration for the "Session expires"
|
||||
// line. Examples: "2h 15m", "47m 12s", "8s", "expired 3m ago".
|
||||
//
|
||||
// Granularity drops to seconds only under a minute, otherwise minutes are
|
||||
// the smallest unit shown — sub-minute precision is noise for a deadline
|
||||
// that's hours or days out.
|
||||
func FormatRemainingDuration(d time.Duration) string {
|
||||
if d <= 0 {
|
||||
return "expired " + HumaniseDuration(-d) + " ago"
|
||||
}
|
||||
return HumaniseDuration(d)
|
||||
}
|
||||
|
||||
// HumaniseDuration renders a positive duration in compact form (e.g.
|
||||
// "2h 15m", "47m", "8s"). Exposed alongside FormatRemainingDuration so
|
||||
// callers that don't need the "expired … ago" wording can format
|
||||
// positive durations directly.
|
||||
func HumaniseDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
s := int(d.Round(time.Second).Seconds())
|
||||
if s < 1 {
|
||||
s = 1
|
||||
}
|
||||
return fmt.Sprintf("%ds", s)
|
||||
}
|
||||
|
||||
const (
|
||||
day = 24 * time.Hour
|
||||
hour = time.Hour
|
||||
minute = time.Minute
|
||||
)
|
||||
|
||||
days := int64(d / day)
|
||||
d -= time.Duration(days) * day
|
||||
hours := int64(d / hour)
|
||||
d -= time.Duration(hours) * hour
|
||||
minutes := int64(d / minute)
|
||||
|
||||
switch {
|
||||
case days > 0:
|
||||
if hours == 0 {
|
||||
return fmt.Sprintf("%dd", days)
|
||||
}
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
case hours > 0:
|
||||
if minutes == 0 {
|
||||
return fmt.Sprintf("%dh", hours)
|
||||
}
|
||||
return fmt.Sprintf("%dh %dm", hours, minutes)
|
||||
default:
|
||||
return fmt.Sprintf("%dm", minutes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -641,3 +641,50 @@ func TestTimeAgo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHumaniseDuration(t *testing.T) {
|
||||
cases := []struct {
|
||||
in time.Duration
|
||||
want string
|
||||
}{
|
||||
{0, "1s"},
|
||||
{500 * time.Millisecond, "1s"},
|
||||
{8 * time.Second, "8s"},
|
||||
{59 * time.Second, "59s"},
|
||||
{time.Minute, "1m"},
|
||||
{47*time.Minute + 12*time.Second, "47m"},
|
||||
{time.Hour, "1h"},
|
||||
{2*time.Hour + 15*time.Minute, "2h 15m"},
|
||||
{2 * time.Hour, "2h"},
|
||||
{24 * time.Hour, "1d"},
|
||||
{2*24*time.Hour + 3*time.Hour, "2d 3h"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := HumaniseDuration(tc.in)
|
||||
assert.Equal(t, tc.want, got, "input %s", tc.in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatRemainingDuration_Expired(t *testing.T) {
|
||||
assert.Equal(t, "expired 3m ago", FormatRemainingDuration(-3*time.Minute))
|
||||
assert.Equal(t, "expired 1s ago", FormatRemainingDuration(-500*time.Millisecond))
|
||||
}
|
||||
|
||||
func TestSessionExpiresLineRendered(t *testing.T) {
|
||||
in := overview // copy of the package-level fixture
|
||||
deadline := time.Now().Add(2*time.Hour + 30*time.Minute).UTC()
|
||||
in.SessionExpiresAt = &deadline
|
||||
|
||||
out := in.GeneralSummary(false, false, false, false)
|
||||
assert.Contains(t, out, "Session expires: ")
|
||||
assert.Contains(t, out, deadline.Format(time.RFC3339))
|
||||
// 2h 30m drifts to "2h 29m" within 60s — match the family prefix.
|
||||
assert.Contains(t, out, "(in 2h ")
|
||||
}
|
||||
|
||||
func TestSessionExpiresLineOmittedWhenNil(t *testing.T) {
|
||||
in := overview
|
||||
in.SessionExpiresAt = nil
|
||||
out := in.GeneralSummary(false, false, false, false)
|
||||
assert.NotContains(t, out, "Session expires")
|
||||
}
|
||||
|
||||
8
client/ui/.gitignore
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
.task
|
||||
bin
|
||||
frontend/dist
|
||||
frontend/node_modules
|
||||
frontend/bindings
|
||||
frontend/.vite
|
||||
build/linux/appimage/build
|
||||
build/windows/nsis/MicrosoftEdgeWebview2Setup.exe
|
||||
152
client/ui/CLAUDE.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# NetBird Wails UI — Working Notes
|
||||
|
||||
This is the Wails v3 desktop UI for NetBird. Go services live in `services/`; the React/TS frontend lives in `frontend/`; bindings between them are generated under `frontend/bindings/`.
|
||||
|
||||
> **Keep these notes current.** When working in this directory with Claude, update this file (and `frontend/CLAUDE.md` for frontend-only changes) whenever you add a service, change an event name, shift a convention, rename a key directory, or land any other change that future-you would want to know about before reading the code. The goal is that a cold-start agent can orient itself from these notes without re-deriving the codebase.
|
||||
|
||||
## Layout
|
||||
|
||||
### Go (top-level package `main`)
|
||||
- `main.go` — app entry. Builds the shared gRPC `Conn`, constructs services, registers them with Wails, creates the main webview window, then starts (in order) the Linux SNI watcher → tray → `peers.Watch` → `app.Run`. CLI flags: `--daemon-addr`, `--log-file` (repeatable; first user-provided value drops the seeded `console` default), `--log-level` (`trace|debug|info|warn|error`, default `info`).
|
||||
- `tray.go` — `Tray` struct + menu. Subscribes to `EventStatus`, `EventSystem`, `EventUpdateAvailable`, `EventUpdateProgress`. Owns per-status icon/dot, Profiles submenu, Connect/Disconnect swap, About → Update, session-expired toast.
|
||||
- `tray_linux.go` — `init()` sets `WEBKIT_DISABLE_DMABUF_RENDERER=1` to avoid the blank-white window on VMs / minimal WMs.
|
||||
- `tray_watcher_linux.go`, `xembed_host_linux.go`, `xembed_tray_linux.{c,h}` — in-process SNI watcher + XEmbed bridge for minimal WMs. See `LINUX-TRAY.md`.
|
||||
- `signal_unix.go` / `signal_windows.go` — `listenForShowSignal`. Unix uses SIGUSR1; Windows uses a named event `Global\NetBirdQuickActionsTriggerEvent`. Mirrors the legacy Fyne UI's external-trigger contract so the installer / CLI keep working.
|
||||
- `grpc.go` — lazy, mutex-protected gRPC `Conn` shared by every service. `DaemonAddr()`: `unix:///var/run/netbird.sock` on Linux/macOS, `tcp://127.0.0.1:41731` on Windows.
|
||||
- `icons.go` — `//go:embed` tray/window PNGs. macOS uses template variants (`*-macos.png`); Linux ships light + dark PNGs; Windows reuses the light PNG (multi-frame `.ico` never redrew on Wails3's `NIM_MODIFY`).
|
||||
|
||||
### Wails services (`services/*.go`)
|
||||
Each service is registered via `app.RegisterService(application.NewService(svc))`. Every method becomes a TS function in `frontend/bindings/.../services/`. Frontend-facing details (TS signatures, push events, models) are in `frontend/WAILS-API.md`. After editing any `services/*.go` or the proto, regenerate with `wails3 generate bindings -clean=true -ts` (or `pnpm bindings` from `frontend/`). `frontend/bindings/**` is gitignored.
|
||||
|
||||
For frontend-side conventions (routing, providers, contexts) see `frontend/CLAUDE.md`.
|
||||
|
||||
## Services rundown
|
||||
|
||||
All services live in `services/` and assume a build tag `!android && !ios && !freebsd && !js`. Each takes a shared `DaemonConn` (`conn.go`) and is registered in `main.go`.
|
||||
|
||||
| Service | File | Responsibility |
|
||||
|---|---|---|
|
||||
| `Connection` | `connection.go` | `Login` / `WaitSSOLogin` / `Up` / `Down` / `Logout` / `OpenURL`. `Up` is always async (`Async: true`); status flows back through `Peers`. `Login` Down-resets the daemon first to dislodge a stale WaitSSOLogin. `OpenURL` honors `$BROWSER`. |
|
||||
| `Settings` | `settings.go` | `GetConfig` / `SetConfig` (partial update — pointer fields are sent, nil fields preserved) / `GetFeatures` (operator-disabled UI surfaces). |
|
||||
| `Profiles` | `profile.go` | `Username` / `List` / `GetActive` / `Switch` / `Add` / `Remove`. `List` populates `Email` from the **user-side** state file (`profilemanager.NewProfileManager().GetProfileState`) — the daemon runs as root and can't read it. |
|
||||
| `ProfileSwitcher` | `profileswitcher.go` | `SwitchActive` — the single entry point both tray and frontend should use for profile flips. Applies the reconnect policy (see "Profile switching" below), mirrors the daemon switch into the user-side `profilemanager`, drives optimistic feedback via `Peers.BeginProfileSwitch`. |
|
||||
| `Peers` | `peers.go` | Daemon status snapshot + two long-running streams (`SubscribeStatus` → `EventStatus`, `SubscribeEvents` → `EventSystem`). Emits synthetic `StatusDaemonUnavailable` when the socket is unreachable. Owns the profile-switch suppression filter (`BeginProfileSwitch` / `CancelProfileSwitch` / `shouldSuppress`). Fan-outs update metadata into dedicated `EventUpdateAvailable` / `EventUpdateProgress` events. |
|
||||
| `Networks` | `network.go` | `List` / `Select` / `Deselect` of routed networks. |
|
||||
| `Forwarding` | `forwarding.go` | `List` exposed/forwarded services from the daemon's reverse-proxy table. |
|
||||
| `Debug` | `debug.go` | `Bundle` (debug bundle creation + optional upload) / `Get|SetLogLevel` / `RevealFile` (cross-platform "show in file manager"). |
|
||||
| `Update` | `update.go` | `GetState` / `Trigger` (enforced installer) / `GetInstallerResult` / `Quit`. The install-progress UI lives in its own auxiliary window (`/#/install-progress`), opened by `WindowManager.OpenInstallProgress` — the daemon goes unreachable mid-install so it can't be inside the main window. |
|
||||
| `WindowManager` | `windowmanager.go` | `OpenSettings(tab)` / `OpenBrowserLogin(uri)` / `CloseBrowserLogin` / `OpenSessionExpired` / `OpenSessionAboutToExpire(seconds)` / `OpenInstallProgress(version)` / `CloseInstallProgress`. `OpenSettings("")` opens the General tab; pass a tab id (e.g. `"profiles"`) to deep-link, encoded as `?tab=…` in the start URL. `OpenInstallProgress` is `AlwaysOnTop` and hides every other visible window for the duration of the install (restored on close). Auxiliary windows are created on first open and **destroyed** on close (Wails-recommended singleton pattern; prevents the macOS dock-reopen from resurrecting hidden windows). |
|
||||
| `I18n` | `i18n.go` | Thin facade over `i18n.Bundle`. `Languages()` returns the shipped locales (`_index.json`); `Bundle(code)` returns the full key→text map for one language so the React layer can drive its own translation library. |
|
||||
| `Preferences` | `preferences.go` | Thin facade over `preferences.Store`. `Get()` returns `{language}`; `SetLanguage(code)` validates against `i18n.Bundle.HasLanguage`, persists, and broadcasts `netbird:preferences:changed`. |
|
||||
|
||||
`DaemonConn` is defined in `services/conn.go`; `ptrStr` (string-to-*string helper for proto pointer fields) lives there too.
|
||||
|
||||
## Daemon proto
|
||||
- Proto source: `../proto/daemon.proto`. Generated Go in `../proto/*.pb.go`.
|
||||
- Regen: `cd ../proto && protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative daemon.proto`
|
||||
- Pinned versions (see `daemon.pb.go` header): `protoc v7.34.1`, `protoc-gen-go v1.36.6`. CI's `proto-version-check` workflow fails on mismatch.
|
||||
- After proto regen, also regen Wails bindings so the TS layer picks up new fields.
|
||||
|
||||
## Events bus
|
||||
|
||||
`main.go` registers five typed events for the frontend: `netbird:status` (`Status`), `netbird:event` (`SystemEvent`), `netbird:profile:changed` (`ProfileRef`), `netbird:update:available` (`UpdateAvailable`), `netbird:update:progress` (`UpdateProgress`). `netbird:profile:changed` fires from `ProfileSwitcher.SwitchActive` after a successful daemon-side switch — both the React `ProfileContext` and the tray subscribe so a flip driven from one surface paints in the others (the daemon itself does not emit a profile event). Plus three plain-string events:
|
||||
|
||||
- `EventTriggerLogin = "trigger-login"` — tray asking the frontend's `startLogin()` to begin an SSO flow. The tray does **not** show the main window when emitting — the hidden webview is alive and subscribed, so `startLogin` runs and the only visible surface is the BrowserLogin popup it opens.
|
||||
- `EventBrowserLoginCancel = "browser-login:cancel"` — the `BrowserLogin` window's Cancel button or red-X close. `startLogin()` listens and tears down the daemon's pending `WaitSSOLogin`.
|
||||
- `preferences.EventPreferencesChanged = "netbird:preferences:changed"` — emitted after every successful `SetLanguage` (payload `{language}`). Both the tray menu rebuild and the React `i18next.changeLanguage` subscribe so a flip from any window paints everywhere.
|
||||
|
||||
Daemon connection status strings (`services/peers.go`) mirror `internal.Status*` in `client/internal/state.go`: `Connected`, `Connecting`, `Idle`, `NeedsLogin`, `LoginFailed`, `SessionExpired`, plus the synthetic `DaemonUnavailable` emitted by `Peers` when the socket is unreachable.
|
||||
|
||||
## Profile switching
|
||||
|
||||
`services/profileswitcher.go` is the single source of truth for the reconnect policy. Both the tray (`tray.go switchProfile`) and the frontend (via `modules/profile/ProfileContext.tsx`'s `switchProfile`, which `modules/settings/SettingsProfiles.tsx` and the header `ProfileDropdown` go through) call `ProfileSwitcher.SwitchActive`; identical inputs give identical state transitions.
|
||||
|
||||
Reconnect policy (driven by `prevStatus` from `Peers.Get`):
|
||||
|
||||
| Previous status | Action | Optimistic UI | Suppressed events until new flow begins |
|
||||
|---|---|---|---|
|
||||
| Connected | Switch + Down + Up | Connecting (synthetic) | Connected, Idle |
|
||||
| Connecting | Switch + Down + Up | Connecting (unchanged) | Connected, Idle |
|
||||
| NeedsLogin / LoginFailed / SessionExpired | Switch + Down | (no change) | — |
|
||||
| Idle | Switch only | (no change) | — |
|
||||
|
||||
Only Connected/Connecting trigger `Peers.BeginProfileSwitch`. That:
|
||||
1. Sets a 30s `switchInProgress` guard.
|
||||
2. Emits a synthetic `Status{Status: StatusConnecting}` so both tray and React paint immediately.
|
||||
3. Tells `statusStreamLoop` to drop the daemon's stale Connected updates (peer count drops as the engine tears down) and the transient Idle in between Down and the new Up.
|
||||
|
||||
`shouldSuppress` releases the guard as soon as a status that signals the new flow began arrives:
|
||||
- **Suppressed**: Connected, Idle
|
||||
- **Pass through and clear**: Connecting / NeedsLogin / LoginFailed / SessionExpired / DaemonUnavailable
|
||||
- **Timeout fallback**: 30s elapsed → clear flag, emit normally.
|
||||
|
||||
`Peers.CancelProfileSwitch` aborts the suppression — called by `tray.go handleDisconnect` so the user's "Disconnect while Connecting" click paints through immediately.
|
||||
|
||||
Also: `ProfileSwitcher.SwitchActive` mirrors the daemon switch into the user-side `profilemanager` (`~/Library/Application Support/netbird/active_profile`). The CLI's `netbird up` reads this file and sends the resolved profile name back; if it diverges from the daemon's `/var/lib/netbird/active_profile.json`, the daemon silently flips back. Mirror failures don't abort the switch — surfaced as a warning.
|
||||
|
||||
## Auxiliary windows (`WindowManager`)
|
||||
|
||||
The main window is created up front in `main.go`. Auxiliary windows are created on demand by `services.WindowManager`:
|
||||
|
||||
- **Settings** (`/#/settings`) — opened from the header gear icon (`layouts/Header.tsx → WindowManager.OpenSettings("")`), the tray's Settings menu entry (`tray.go openSettings`), and the profile dropdown's "Manage Profiles" entry (`WindowManager.OpenSettings("profiles")`, which sets `?tab=profiles` in the start URL — `Settings.tsx` reads it via `useSearchParams`). The window hosts every settings tab — including **Profiles** (`SettingsProfiles.tsx`, `UserCircle` icon, sits between Security and SSH), which lists profiles in a table with Deregister/Delete in a per-row kebab and an Add Profile button. Both call sites go through `WindowManager` so the user sees the same dedicated frameless window from either trigger — the tray used to repurpose the main window via `SetURL("/#/settings")`, which replaced the main UI in place. Frameless-look (opaque macOS backdrop, hidden inset title bar), fixed 900×640, no resize, no minimise/maximise.
|
||||
- **BrowserLogin** (`/#/browser-login?uri=…`) — opened by the connection toggle's SSO flow (`layouts/ConnectionStatusSwitch.tsx`). 460×440, fixed size. The close button (red X) fires `EventBrowserLoginCancel` so the JS-side `startLogin()` can tear down the daemon's pending `WaitSSOLogin`. `WindowManager.CloseBrowserLogin` closes it programmatically when the flow completes.
|
||||
- **SessionExpired** (`/#/session-expired`) and **SessionAboutToExpire** (`/#/session-about-to-expire?seconds=<n>`) — opened by `WindowManager.OpenSessionExpired` / `OpenSessionAboutToExpire(seconds)`. 460×380, fixed size, `AlwaysOnTop: true` (the user can't miss them). The React-side buttons close the window via `WindowManager.CloseSession*` and (for Sign-in / Stay-connected) emit `EventTriggerLogin` so the main window's `startLogin()` orchestrator handles the SSO flow. Currently triggered only by the DEV-only "Development" Settings tab; daemon-status integration is a follow-up.
|
||||
- **InstallProgress** (`/#/install-progress?version=<v>`) — opened by `WindowManager.OpenInstallProgress(version)` from `ClientVersionContext` (force-install branch on `installing` flip, user-driven enforced branch from `triggerUpdate`). 360-wide auto-sized via `useAutoSizeWindow`, `AlwaysOnTop`. Owns its own polling loop against `Update.GetInstallerResult` with the 5-second daemon-down-grace (sustained gRPC failure = success → call `Update.Quit()`). Hides every other visible window on open (restored on close). The DEV-only "Development" tab has a "Show updating dialog" button that opens this window directly for preview.
|
||||
|
||||
All five auxiliary windows are **destroyed** on close (mutex-guarded singleton; `closing` hook nils the field). Destroying rather than hiding is deliberate — Wails' macOS dock-reopen handler resurrects hidden windows, which we don't want for auxiliaries.
|
||||
|
||||
The main window is **hidden** on close (the `WindowClosing` hook calls `e.Cancel(); window.Hide()`). The user reaches "really quit" through the tray → Quit menu entry.
|
||||
|
||||
## Localisation (i18n)
|
||||
|
||||
The locale tree under `frontend/src/i18n/locales/` is the single source of truth for both Go (tray, OS notifications) and React (every user-facing string). Layout: `_index.json` lists shipped languages (`code` / `displayName` / `englishName`); `<code>/common.json` per language. `en/common.json` must exist (the `Bundle` loader hard-fails without it); languages listed in `_index.json` without a bundle are skipped with a warning. Placeholders are single-braced (`"Install version {version}"`) — Go substitutes via `Bundle.Translate(lang, key, "name", value, ...)`; React uses i18next with `interpolation: { prefix: "{", suffix: "}" }`.
|
||||
|
||||
Adding a language: drop a `<code>/common.json`, append a row to `_index.json`, add the static import in `frontend/src/i18n/index.ts`, rebuild. Embed lives in `client/ui/main.go`'s `embed.FS`.
|
||||
|
||||
Package layout:
|
||||
- `client/ui/i18n/` — pure `LanguageCode` / `Language` / `Bundle` loader. No Wails / no daemon. Reads the tree from an `fs.FS` passed in by `main.go`.
|
||||
- `client/ui/preferences/` — `Store` persists `UIPreferences{language}` to `os.UserConfigDir()/netbird/ui-preferences.json` (per-OS-user, shared across daemon profiles). Validates against an injected `LanguageValidator` (`*i18n.Bundle`). No file → in-memory default `en`, persisted on first `SetLanguage`. Broadcasts via in-process pub/sub + optional Wails event emitter.
|
||||
- `services/i18n.go` + `services/preferences.go` — Wails facades. Preferences emits `netbird:preferences:changed` (payload `{language}`) on every `SetLanguage`.
|
||||
|
||||
Key conventions: `tray.*` / `notify.*` (Go-side), `common.* / connect.* / nav.* / profile.* / settings.* / update.* / browserLogin.* / sessionExpired.* / peers.*` (frontend). Keep keys stable — renames cascade everywhere.
|
||||
|
||||
## Linux tray support
|
||||
|
||||
The in-process `StatusNotifierWatcher` + XEmbed host that lets the tray work on minimal WMs is detailed in `LINUX-TRAY.md` (sibling). Touch that doc when modifying `tray_watcher_linux.go` / `xembed_host_linux.go` / `xembed_tray_linux.{c,h}`.
|
||||
|
||||
## Wails Dialogs (frontend, `@wailsio/runtime`)
|
||||
|
||||
API surface — `Dialogs.Info` / `Warning` / `Error` / `Question` / `OpenFile` / `SaveFile`, options shape, per-OS behaviour, and the Go-side frameless-window pattern — lives in `WAILS-DIALOGS.md` (sibling). The conventions for **when** to use a native dialog vs inline UI are in the "Conventions" section below.
|
||||
|
||||
## Conventions in this codebase
|
||||
|
||||
### Errors → native dialogs
|
||||
|
||||
User-actionable operation failures (config save, profile switch, debug bundle, update, etc.) surface via `Dialogs.Error` with an action-named title — "Save Settings Failed", "Switch Profile Failed", not "Error" / "Something went wrong". The dialog itself already says "Error" visually.
|
||||
|
||||
Confirmations use `Dialogs.Warning` with explicit `Buttons`. The promise resolves with the **button Label string**, not an index — pin the label into a variable before comparing (especially with i18n, where labels translate). Full API in `WAILS-DIALOGS.md`.
|
||||
|
||||
**Skip native dialogs** for: inline form validation (`Input.tsx`, URL-format checks — too heavy for keystroke feedback); transient link errors on the dashboard (flap in/out with daemon — use an inline indicator); "partial success" notes inside an otherwise-OK flow (e.g. "bundle saved but upload failed" stays inline). The install-progress window owns its own error UI in-place (timeout/canceled/failed phases) — no native dialog needed there.
|
||||
|
||||
### OS notifications
|
||||
|
||||
The tray uses Wails' built-in `notifications` service. One `notifications.NotificationService` is created in `main.go` and passed into `TrayServices.Notifier`. Notification IDs are prefixed for coalescing: `netbird-update-<version>`, `netbird-event-<id>`, `netbird-tray-error`, `netbird-session-expired`. Notifications are gated by the user's "Notifications" toggle (cached in `Tray.notificationsEnabled`, seeded from `Settings.GetConfig` at boot). `Severity == "critical"` events bypass the gate, mirroring the legacy Fyne `event.Manager`.
|
||||
|
||||
### Profile switching invariants
|
||||
|
||||
`ProfileSwitcher.SwitchActive` is the only switch path on the TS side — `ProfileContext.switchProfile` is the single TS wrapper, and `modules/settings/SettingsProfiles.tsx` + the header `ProfileDropdown` both go through it. The Go side captures `prevStatus`, drives the optimistic-Connecting paint via `Peers.BeginProfileSwitch`, mirrors into the user-side `profilemanager`, and conditionally fires Down/Up per the reconnect-policy table above.
|
||||
|
||||
**Never call `Connection.Up` on an Idle/NeedsLogin daemon** — the daemon's internal 50s `waitForUp` blocks until `DeadlineExceeded`. `Connection.Up` from the frontend is reserved for the explicit Connect button (`ConnectionStatusSwitch.connect`) and the post-SSO resume inside `startLogin`; the gating for profile-switch reconnects lives Go-side in `ProfileSwitcher.SwitchActive`.
|
||||
|
||||
## Build / dev tasks
|
||||
|
||||
`task dev` (Wails dev, live reload), `task build` (prod build for the current OS, dispatches to `build/{darwin,linux,windows}/Taskfile.yml`), `task build:server` / `run:server` / `build:docker` / `run:docker` (server-mode variants in `build/Taskfile.yml`). **No** `task generate:bindings` alias — run `wails3 generate bindings -clean=true -ts` directly from this directory. CLI flags + log-target semantics are documented in the `main.go` bullet under "Layout".
|
||||
|
||||
## Useful references
|
||||
- `WAILS-DIALOGS.md` (sibling) — full `@wailsio/runtime` `Dialogs` API + per-OS behaviour + frameless-window pattern.
|
||||
- `LINUX-TRAY.md` (sibling) — StatusNotifierWatcher + XEmbed host details.
|
||||
- `frontend/WAILS-API.md` — frontend-facing binding signatures and model shapes.
|
||||
- Wails v3 dialog docs: https://v3.wails.io/features/dialogs/message/ and https://v3.wails.io/features/dialogs/custom/ (may 403 from some clients).
|
||||
- Wails v3 multiple-windows guidance: https://v3.wails.io/learn/multiple-windows/
|
||||
- Authoritative TS signatures: `frontend/node_modules/@wailsio/runtime/types/dialogs.d.ts`.
|
||||
- Wails examples: https://github.com/wailsapp/wails/tree/master/v3/examples/dialogs
|
||||
8
client/ui/LINUX-TRAY.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Linux tray support (StatusNotifierWatcher + XEmbed)
|
||||
|
||||
Minimal WMs (Fluxbox, OpenBox, i3, dwm, vanilla GNOME without the AppIndicator extension) don't ship a `StatusNotifierWatcher`, so tray icons using libayatana-appindicator / freedesktop StatusNotifier silently fail. `main.go` calls `startStatusNotifierWatcher()` *before* `NewTray` so the Wails systray's `RegisterStatusNotifierItem` call hits the in-process watcher we control.
|
||||
|
||||
- `tray_watcher_linux.go` — owns `org.kde.StatusNotifierWatcher` on the session bus if no other process has it. Safe to call unconditionally.
|
||||
- `xembed_host_linux.go` + `xembed_tray_linux.{c,h}` — when an XEmbed tray (`_NET_SYSTEM_TRAY_S0`) is available, also start an in-process XEmbed host that bridges the SNI icon into the XEmbed tray. Reads `IconPixmap` over D-Bus, draws via cairo+X11, polls for clicks, fetches `com.canonical.dbusmenu.GetLayout` for the popup menu, fires `com.canonical.dbusmenu.Event` on click.
|
||||
|
||||
Build is gated on `linux && !386`; the 386 build (no cgo) and non-Linux builds use the `tray_watcher_other.go` no-op.
|
||||
100
client/ui/README.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# NetBird desktop UI (Wails3 + React)
|
||||
|
||||
Replaces `client/ui` (Fyne). One binary on Windows / macOS / Linux,
|
||||
talks to the NetBird daemon over gRPC, renders a React frontend in a
|
||||
WebView.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Go ≥ 1.25, Node ≥ 20, **pnpm** (`corepack enable && corepack prepare pnpm@latest --activate`)
|
||||
- `wails3` CLI: `go install github.com/wailsapp/wails/v3/cmd/wails3@latest`
|
||||
- `task`: `go install github.com/go-task/task/v3/cmd/task@latest`
|
||||
- A running NetBird daemon (default: `unix:///var/run/netbird.sock`,
|
||||
Windows `tcp://127.0.0.1:41731`)
|
||||
- Linux only: `libwebkit2gtk-4.1-dev`, `libgtk-3-dev`,
|
||||
`libayatana-appindicator3-dev`
|
||||
|
||||
## Develop without rebuilding
|
||||
|
||||
```bash
|
||||
cd client/ui
|
||||
task dev
|
||||
```
|
||||
|
||||
`task dev` runs Vite (port 9245) + the Go binary + a `*.go` watcher.
|
||||
Frontend edits hot-reload instantly. Go edits trigger a rebuild and
|
||||
relaunch. Pass daemon flags after `--`:
|
||||
|
||||
```bash
|
||||
task dev -- --daemon-addr=tcp://127.0.0.1:41731
|
||||
```
|
||||
|
||||
For pure UI work (no native window, fastest loop):
|
||||
|
||||
```bash
|
||||
cd frontend && pnpm dev
|
||||
```
|
||||
|
||||
## Production build
|
||||
|
||||
```bash
|
||||
task build
|
||||
```
|
||||
|
||||
Output in `bin/`. Frontend assets are embedded into the binary.
|
||||
|
||||
### Cross-compile Windows from Linux
|
||||
|
||||
Install the mingw-w64 toolchain once:
|
||||
|
||||
```bash
|
||||
sudo apt install gcc-mingw-w64-x86-64 # Debian/Ubuntu
|
||||
sudo dnf install mingw64-gcc # Fedora
|
||||
sudo pacman -S mingw-w64-gcc # Arch
|
||||
```
|
||||
|
||||
Then:
|
||||
|
||||
```bash
|
||||
CGO_ENABLED=1 task windows:build
|
||||
```
|
||||
|
||||
Produces `bin/netbird-ui.exe`. macOS cross-compile from Linux is not
|
||||
supported (signing and notarization need a real Mac).
|
||||
|
||||
### Windows console build (logs in the terminal)
|
||||
|
||||
Default `windows:build` links the binary as a Windows GUI app, which
|
||||
detaches from the launching console — `logrus` output, `fmt.Println`,
|
||||
and panics go nowhere visible. To debug tray/event/daemon issues:
|
||||
|
||||
```bash
|
||||
CGO_ENABLED=1 task windows:build:console
|
||||
```
|
||||
|
||||
Produces `bin/netbird-ui-console.exe`. Run it from `cmd.exe` /
|
||||
PowerShell / Windows Terminal and stdout/stderr land in that
|
||||
terminal. Same flag works on a native Windows build (drop the
|
||||
`CGO_ENABLED=1` if your toolchain already has it set).
|
||||
|
||||
## Regenerating bindings
|
||||
|
||||
When a Go service signature changes:
|
||||
|
||||
```bash
|
||||
wails3 generate bindings
|
||||
```
|
||||
|
||||
`task dev` does this automatically on `*.go` save.
|
||||
|
||||
## Tray icons
|
||||
|
||||
Source SVGs live in `assets/svg/` (state.svg + state-macos.svg). After editing
|
||||
any SVG, rasterize to the PNGs the Go side embeds:
|
||||
|
||||
```bash
|
||||
task common:generate:tray:icons
|
||||
```
|
||||
|
||||
Requires Inkscape. Commit the resulting `assets/*.png` files alongside the
|
||||
SVG change so CI doesn't need Inkscape installed.
|
||||
58
client/ui/Taskfile.yml
Normal file
@@ -0,0 +1,58 @@
|
||||
version: '3'
|
||||
|
||||
includes:
|
||||
common: ./build/Taskfile.yml
|
||||
windows: ./build/windows/Taskfile.yml
|
||||
darwin: ./build/darwin/Taskfile.yml
|
||||
linux: ./build/linux/Taskfile.yml
|
||||
|
||||
vars:
|
||||
APP_NAME: "netbird-ui"
|
||||
BIN_DIR: "bin"
|
||||
VITE_PORT: '{{.WAILS_VITE_PORT | default 9245}}'
|
||||
|
||||
tasks:
|
||||
build:
|
||||
summary: Builds the application
|
||||
cmds:
|
||||
- task: "{{OS}}:build"
|
||||
|
||||
package:
|
||||
summary: Packages a production build of the application
|
||||
cmds:
|
||||
- task: "{{OS}}:package"
|
||||
|
||||
run:
|
||||
summary: Runs the application
|
||||
cmds:
|
||||
- task: "{{OS}}:run"
|
||||
|
||||
dev:
|
||||
summary: Runs the application in development mode
|
||||
cmds:
|
||||
- wails3 dev -config ./build/config.yml -port {{.VITE_PORT}}
|
||||
|
||||
setup:docker:
|
||||
summary: Builds Docker image for cross-compilation (~800MB download)
|
||||
cmds:
|
||||
- task: common:setup:docker
|
||||
|
||||
build:server:
|
||||
summary: Builds the application in server mode (no GUI, HTTP server only)
|
||||
cmds:
|
||||
- task: common:build:server
|
||||
|
||||
run:server:
|
||||
summary: Runs the application in server mode
|
||||
cmds:
|
||||
- task: common:run:server
|
||||
|
||||
build:docker:
|
||||
summary: Builds a Docker image for server mode deployment
|
||||
cmds:
|
||||
- task: common:build:docker
|
||||
|
||||
run:docker:
|
||||
summary: Builds and runs the Docker image
|
||||
cmds:
|
||||
- task: common:run:docker
|
||||
56
client/ui/WAILS-DIALOGS.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Wails Dialogs (frontend, `@wailsio/runtime`)
|
||||
|
||||
The frontend dialog API lives in `@wailsio/runtime` as `Dialogs`. Authoritative signatures are in
|
||||
`frontend/node_modules/@wailsio/runtime/types/dialogs.d.ts`.
|
||||
|
||||
See `CLAUDE.md` for project conventions on *when* to use these (errors vs. inline validation, confirmation flow, etc.).
|
||||
|
||||
## Message dialogs
|
||||
|
||||
```ts
|
||||
import { Dialogs } from "@wailsio/runtime";
|
||||
|
||||
await Dialogs.Info({ Title, Message, Buttons?, Detached? });
|
||||
await Dialogs.Warning({ Title, Message, Buttons?, Detached? });
|
||||
await Dialogs.Error({ Title, Message, Buttons?, Detached? });
|
||||
await Dialogs.Question({ Title, Message, Buttons?, Detached? });
|
||||
```
|
||||
|
||||
All four return `Promise<string>` resolving to the **Label** of the button the user clicked. With no `Buttons` provided you get a single OK button — the promise just resolves when the user dismisses.
|
||||
|
||||
`MessageDialogOptions` fields:
|
||||
- `Title?: string` — window title (short).
|
||||
- `Message?: string` — the body text.
|
||||
- `Buttons?: Button[]` — custom buttons. Each `Button` is `{ Label?, IsCancel?, IsDefault? }`. `IsCancel` is what Esc/⌘. triggers; `IsDefault` is what Enter triggers.
|
||||
- `Detached?: boolean` — when `true`, the dialog isn't tied to the parent window (no sheet behavior on macOS).
|
||||
|
||||
## File dialogs
|
||||
|
||||
`Dialogs.OpenFile(options)` and `Dialogs.SaveFile(options)` — see `dialogs.d.ts` for the full `OpenFileDialogOptions` / `SaveFileDialogOptions` field set (filters, ButtonText, multi-select, hidden files, alias resolution, directory mode, etc).
|
||||
|
||||
## Per-OS behavior
|
||||
|
||||
| Platform | Behavior |
|
||||
|---|---|
|
||||
| **macOS** | Sheet-style when attached to a parent window. Up to ~4 custom buttons render naturally. Keyboard: Enter = default, ⌘. or Esc = cancel. Follows system theme. Accessibility is built-in. |
|
||||
| **Windows** | Modal `TaskDialog`-style. Standard button labels are nudged toward OS conventions. Keyboard: Enter = default, Esc = cancel. Follows system theme. |
|
||||
| **Linux** | GTK dialogs — appearance varies by desktop environment (GNOME/KDE). Follows desktop theme. Standard keyboard nav. |
|
||||
|
||||
Behavioural notes that affect us:
|
||||
- The promise resolves with the **button label string**, not an index. Compare against the literal `Label` you passed (e.g. `if (result !== "Delete") return;`).
|
||||
- `Buttons[]` on Linux/Windows uses the labels you supply, but the OS layout/styling is fixed.
|
||||
- `Dialogs.Error` plays the platform error sound and uses the platform error icon. Don't use it for confirmations — use `Dialogs.Warning` or `Dialogs.Question`.
|
||||
- Don't fire dialogs in a tight loop or from every keystroke — they interrupt focus and (on macOS) animate in/out. Debounce or guard with a `busy` flag.
|
||||
|
||||
## Frameless / custom-window dialogs (Go side)
|
||||
|
||||
When the native dialog API isn't enough — rich content, embedded webview, multi-screen flow — open a regular Wails window. This is done on the **Go side** via `app.Window.NewWithOptions(application.WebviewWindowOptions{...})`. Useful options:
|
||||
- `Parent` — attach to a parent so OS treats it as a child.
|
||||
- `AlwaysOnTop: true` — float above the parent.
|
||||
- `Frameless: true` — no titlebar/chrome.
|
||||
- `Resizable: false` (also `DisableResize: true` in v3) — fixed-size dialog feel.
|
||||
- `Hidden: true` initially, then `dialog.Show()` + `dialog.SetFocus()`.
|
||||
|
||||
We **do** use this pattern, but pragmatically: `WindowManager.OpenSettings` and `OpenBrowserLogin` are regular small webview windows (not modal sheets) with no resize, hidden minimise/maximise buttons, and a translucent macOS title bar. They're not classic "OS modal dialogs"; they're just lightweight ancillary windows that look the part. Modal behaviour (`parent.SetEnabled(false)`) is intentionally not used — the user can still click back to the main window.
|
||||
|
||||
In-app modals (`NewProfileDialog`, delete-profile confirmation, etc.) are Radix `Dialog` primitives inside the main webview. Reach for a custom OS window only when content must escape the main window (BrowserLogin is the canonical example — its lifecycle is tied to the SSO wait) or when the window needs its own taskbar entry / dock icon.
|
||||
|
Before Width: | Height: | Size: 4.6 KiB |
|
Before Width: | Height: | Size: 10 KiB |
|
Before Width: | Height: | Size: 4.9 KiB |
|
Before Width: | Height: | Size: 7.4 KiB |
BIN
client/ui/assets/netbird-menu-16.png
Normal file
|
After Width: | Height: | Size: 526 B |
BIN
client/ui/assets/netbird-menu-24.png
Normal file
|
After Width: | Height: | Size: 739 B |
BIN
client/ui/assets/netbird-menu-dot-connected-16.png
Normal file
|
After Width: | Height: | Size: 508 B |
BIN
client/ui/assets/netbird-menu-dot-connected-22.png
Normal file
|
After Width: | Height: | Size: 615 B |
BIN
client/ui/assets/netbird-menu-dot-connected.png
Normal file
|
After Width: | Height: | Size: 452 B |
BIN
client/ui/assets/netbird-menu-dot-connecting-16.png
Normal file
|
After Width: | Height: | Size: 520 B |
BIN
client/ui/assets/netbird-menu-dot-connecting-22.png
Normal file
|
After Width: | Height: | Size: 637 B |
BIN
client/ui/assets/netbird-menu-dot-connecting.png
Normal file
|
After Width: | Height: | Size: 452 B |
BIN
client/ui/assets/netbird-menu-dot-error-16.png
Normal file
|
After Width: | Height: | Size: 532 B |
BIN
client/ui/assets/netbird-menu-dot-error-22.png
Normal file
|
After Width: | Height: | Size: 629 B |
BIN
client/ui/assets/netbird-menu-dot-error.png
Normal file
|
After Width: | Height: | Size: 433 B |
BIN
client/ui/assets/netbird-menu-dot-idle-16.png
Normal file
|
After Width: | Height: | Size: 490 B |
BIN
client/ui/assets/netbird-menu-dot-idle-22.png
Normal file
|
After Width: | Height: | Size: 602 B |
BIN
client/ui/assets/netbird-menu-dot-idle.png
Normal file
|
After Width: | Height: | Size: 483 B |
BIN
client/ui/assets/netbird-menu-dot-login-16.png
Normal file
|
After Width: | Height: | Size: 537 B |
BIN
client/ui/assets/netbird-menu-dot-login-22.png
Normal file
|
After Width: | Height: | Size: 641 B |
BIN
client/ui/assets/netbird-menu-dot-login.png
Normal file
|
After Width: | Height: | Size: 475 B |
BIN
client/ui/assets/netbird-menu-dot-offline-16.png
Normal file
|
After Width: | Height: | Size: 512 B |
BIN
client/ui/assets/netbird-menu-dot-offline-22.png
Normal file
|
After Width: | Height: | Size: 605 B |
BIN
client/ui/assets/netbird-menu-dot-offline.png
Normal file
|
After Width: | Height: | Size: 456 B |
|
Before Width: | Height: | Size: 103 KiB |
|
Before Width: | Height: | Size: 3.8 KiB After Width: | Height: | Size: 3.6 KiB |
|
Before Width: | Height: | Size: 103 KiB |
|
Before Width: | Height: | Size: 103 KiB |
|
Before Width: | Height: | Size: 3.8 KiB After Width: | Height: | Size: 3.6 KiB |
|
Before Width: | Height: | Size: 103 KiB |