Compare commits

..

20 Commits

Author SHA1 Message Date
Viktor Liu
cca46f070b Merge remote-tracking branch 'origin/main' into windows-dns-firewall
# Conflicts:
#	client/internal/dns/host_windows.go
2026-05-15 17:11:38 +02:00
Viktor Liu
9ed2e2a5b4 [client] Drop DNS probes for passive health projection (#5971) 2026-05-15 17:07:38 +02:00
Viktor Liu
2ccae7ec47 [client] Mirror v4 exit selection onto v6 pair and honour SkipAutoApply per route (#6150) 2026-05-15 16:58:47 +02:00
Viktor Liu
07e5450117 [management] Bracket IPv6 reverse-proxy target hosts when building URL Host field (#6141) 2026-05-14 16:42:40 +02:00
Viktor Liu
3f914090cb [client] Bracket IPv6 in embed listeners, expand debug bundle (#6134) 2026-05-14 16:22:53 +02:00
Viktor Liu
ea9fab4396 [management] Allocate and preserve IPv6 overlay addresses for embedded proxy peers (#6132) 2026-05-14 16:05:33 +02:00
Vlad
77b479286e [management] fix offline statuses for public proxy clusters (#6133) 2026-05-14 13:27:50 +02:00
Maycon Santos
ab2a8794e7 [client] Add short flags for status command options (#6137)
* [client] Add short flags for status command options

* uppercase filters
2026-05-14 12:30:42 +02:00
Viktor Liu
9126a192ca [client] Set 0644 perms on SSH client config after os.CreateTemp (#6126) 2026-05-12 15:05:53 +02:00
Viktor Liu
1224d6e1ee [client] Persist management URL and pre-shared key overrides on login (#6065) 2026-05-12 14:52:56 +02:00
Nicolas Frati
96672dd1f8 [management] chores: update dex version (#6124)
* chores: update dex version

* chore: update dex fork
2026-05-12 13:50:35 +02:00
Viktor Liu
946ce4c3da [client] Fix --config flag default to point at profile path (#6122) 2026-05-11 17:48:21 +02:00
Vlad
07cbfdbede [proxy] feature: bring your own proxy (#5627) 2026-05-11 14:31:38 +02:00
Viktor Liu
5f8b88471f Initialize dnsFirewall in registryConfigurator tests 2026-05-06 11:56:19 +02:00
Viktor Liu
f42b8aed90 Reject port 0 in NB_DNS_FIREWALL_PORTS and roll back firewall on DNS setup failure 2026-05-05 18:38:45 +02:00
Viktor Liu
0415137acd Address CodeRabbit nits: errors.As, defensive disable, port-aware filter name, log wording, provenance 2026-05-05 18:29:23 +02:00
Viktor Liu
7fd16666e3 Fix Windows lint: handle close error and exclude vendored WFP types from unused 2026-05-05 18:25:53 +02:00
Viktor Liu
0571eeaba0 Move strictMode to Windows-only and add manager unit tests 2026-05-05 18:24:09 +02:00
Viktor Liu
6a201d12b5 Extract applyRouteAll helper and reorder package declarations 2026-05-05 18:16:28 +02:00
Viktor Liu
4810e79a00 Add Windows DNS firewall to block DNS leaks from non-netbird processes 2026-05-05 18:11:06 +02:00
871 changed files with 13338 additions and 42173 deletions

View File

@@ -43,13 +43,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the grep then drops the broken package by path. Without -e,
# go list aborts with empty stdout and `go test` falls back to the repo
# root, which has no Go files.
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)

View File

@@ -51,7 +51,7 @@ jobs:
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libwebkit2gtk-4.1-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
@@ -141,7 +141,7 @@ jobs:
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libwebkit2gtk-4.1-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
@@ -154,15 +154,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the grep then drops the broken package by path. Without -e,
# go list aborts with empty stdout and `go test` falls back to the repo
# root, which has no Go files.
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -222,7 +214,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -e -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
'
test_relay:

View File

@@ -64,15 +64,8 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- name: Generate test script
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the Where-Object pipeline then drops the broken package by
# path. Without -e, go list aborts with empty stdout.
run: |
$packages = go list -e ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } | Where-Object { $_ -notmatch '/client/ui' }
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd

View File

@@ -20,7 +20,7 @@ jobs:
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**,**/pnpm-lock.yaml,**/package-lock.json
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:
fail-fast: false
@@ -50,16 +50,7 @@ jobs:
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libwebkit2gtk-4.1-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Stub Wails frontend bundle
# client/ui/main.go has //go:embed all:frontend/dist. The
# directory is produced by `pnpm run build` and is gitignored, so
# lint-only runs (no frontend toolchain) need a placeholder file
# for the embed pattern to match.
shell: bash
run: |
mkdir -p client/ui/frontend/dist
touch client/ui/frontend/dist/.embed-placeholder
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:

View File

@@ -186,9 +186,9 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
@@ -349,18 +349,8 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Set up pnpm
uses: pnpm/action-setup@v3
with:
version: 9
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libwebkit2gtk-4.1-dev libsoup-3.0-dev libayatana-appindicator3-dev gcc-mingw-w64-x86-64
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
@@ -379,16 +369,10 @@ jobs:
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Install wails3 CLI
# Version derived from go.mod so the binding generator always matches
# the wails runtime the binary links against.
run: |
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
- name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
@@ -455,14 +439,6 @@ jobs:
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Set up pnpm
uses: pnpm/action-setup@v3
with:
version: 9
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
@@ -552,6 +528,24 @@ jobs:
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
with:
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
- name: Move opengl32.dll into dist (amd64 only)
if: matrix.arch == 'amd64'
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: carlosperate/download-file-action@v2
with:

View File

@@ -92,6 +92,9 @@ linters:
- linters:
- unused
path: client/firewall/iptables/rule\.go
- linters:
- unused
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
- linters:
- gosec
- mirror
@@ -114,16 +117,6 @@ linters:
- linters:
- staticcheck
text: "QF1012"
# client/ui/main.go uses //go:embed all:frontend/dist; the
# directory is populated by `pnpm build` in the release pipeline
# and missing at lint time, so the embed parses to "no matching
# files found" — surfaced by golangci-lint's typecheck pre-pass.
# Suppress just that one diagnostic; the rest of the package
# (services/, tray.go, grpc.go, ...) still gets linted normally.
- linters:
- typecheck
path: client/ui/main\.go
text: "pattern all:frontend/dist"
paths:
- third_party$
- builtin$

View File

@@ -1,15 +1,6 @@
version: 2
project_name: netbird-ui
before:
hooks:
# Bindings are gitignored; regenerate before the frontend build so
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
# build without them).
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
builds:
- id: netbird-ui
dir: client/ui
@@ -79,15 +70,12 @@ nfpms:
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/build/linux/netbird.desktop
- src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/build/appicon.png
- src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
- libgtk-3-0
- libwebkit2gtk-4.1-0
- libayatana-appindicator3-1
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
@@ -101,15 +89,12 @@ nfpms:
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/build/linux/netbird.desktop
- src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/build/appicon.png
- src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
- gtk3
- webkit2gtk4.1
- libayatana-appindicator-gtk3
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'

View File

@@ -1,11 +1,6 @@
version: 2
project_name: netbird-ui
before:
hooks:
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
builds:
- id: netbird-ui-darwin
dir: client/ui
@@ -25,6 +20,8 @@ builds:
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin

View File

@@ -143,7 +143,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", profilemanager.DefaultConfigPath, "Overrides the default profile file location")
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)

View File

@@ -43,16 +43,16 @@ func init() {
ipsFilterMap = make(map[string]struct{})
prefixNamesFilterMap = make(map[string]struct{})
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
}
func statusFunc(cmd *cobra.Command, args []string) error {

View File

@@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
listenAddr := net.JoinHostPort(addr.String(), port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
@@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
listenAddr := net.JoinHostPort(addr.String(), port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {

View File

@@ -6,7 +6,7 @@
!define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
!define INSTALLER_NAME "netbird-installer.exe"
!define MAIN_APP_EXE "Netbird"
!define ICON "ui\\build\\windows\\icon.ico"
!define ICON "ui\\assets\\netbird.ico"
!define BANNER "ui\\build\\banner.bmp"
!define LICENSE_DATA "..\\LICENSE"
@@ -280,43 +280,6 @@ CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
SectionEnd
# Install the Microsoft Edge WebView2 runtime if it isn't already present.
# Macro adapted from Wails3's NSIS template (wails_tools.nsh): a registry
# probe followed by a silent install of the embedded evergreen bootstrapper.
# The MicrosoftEdgeWebview2Setup.exe payload is staged next to this script
# by the sign-pipelines build step (`wails3 generate webview2bootstrapper`).
!macro nb.webview2runtime
SetRegView 64
# Per-machine install marker — populated when the runtime ships with
# Edge or has been installed by an admin previously.
ReadRegStr $0 HKLM "SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
${If} $0 != ""
Goto webview2_ok
${EndIf}
# Per-user fallback for HKCU installs.
ReadRegStr $0 HKCU "Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
${If} $0 != ""
Goto webview2_ok
${EndIf}
SetDetailsPrint both
DetailPrint "Installing: WebView2 Runtime"
SetDetailsPrint listonly
InitPluginsDir
CreateDirectory "$pluginsdir\webview2bootstrapper"
SetOutPath "$pluginsdir\webview2bootstrapper"
File "MicrosoftEdgeWebview2Setup.exe"
ExecWait '"$pluginsdir\webview2bootstrapper\MicrosoftEdgeWebview2Setup.exe" /silent /install'
SetDetailsPrint both
webview2_ok:
!macroend
Section -WebView2
!insertmacro nb.webview2runtime
SectionEnd
Section -Post
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start'
@@ -363,9 +326,9 @@ DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
# Legacy: pre-Wails installs shipped opengl32.dll (Mesa3D for Fyne); remove
# any leftover copy on uninstall so old upgrades don't leave it behind.
!if ${ARCH} == "amd64"
Delete "$INSTDIR\opengl32.dll"
!endif
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR"

View File

@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
@@ -258,15 +256,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
// On daemon shutdown / Down() the parent context is cancelled
// and the dial fails with "context canceled". Wrapping that
// into state would leave the snapshot stuck at Connecting+err
// until the backoff loop wakes up — instead let the operation
// return cleanly so the deferred state.Set(StatusIdle) takes
// effect on the next iteration.
if c.ctx.Err() != nil {
return nil
}
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
@@ -435,11 +424,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
c.statusRecorder.ClientStart()
// Wrap the backoff with c.ctx so Down()/actCancel propagates into the
// inter-attempt sleep — otherwise a 15s MaxInterval can keep the retry
// loop alive long after the caller asked to give up, leaving the
// status stream stuck at Connecting.
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
err = backoff.Retry(operation, backOff)
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {

View File

@@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
@@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird
Other non-sensitive configuration options are included without anonymization.
Firewall Rules (Linux only)
The bundle includes two separate firewall rule files:
The bundle includes the following firewall-related files:
iptables.txt:
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
- Includes all tables (filter, nat, mangle, raw, security)
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
ip6tables.txt:
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
- Same table coverage and anonymization as iptables.txt
- Omitted when ip6tables is not installed or no IPv6 rules are present
ipset.txt:
- Output of 'ipset list' (family-agnostic)
- IP addresses are anonymized; set names and types remain unchanged
nftables.txt:
- Complete nftables ruleset obtained via 'nft -a list ruleset'
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
- Includes rule handle numbers and packet counters
- All tables, chains, and rules are included
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
- All IP addresses are anonymized; chain/table names remain unchanged
sysctls.txt:
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
- Interface names anonymized when --anonymize is set
IP Rules (Linux only)
The ip_rules.txt file contains detailed IP routing rule information:
@@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() {
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
if err := g.addSysctls(); err != nil {
log.Errorf("failed to add sysctls to debug bundle: %v", err)
}
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}

View File

@@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) {
// addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules")
iptablesRules, err := collectIPTablesRules()
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err)
log.Warnf("Failed to collect ipset information: %v", err)
} else {
if g.anonymize {
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
}
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err)
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
log.Warnf("Failed to add ipset output to bundle: %v", err)
}
}
@@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error {
return nil
}
// collectIPTablesRules collects rules using both iptables-save and verbose listing
func collectIPTablesRules() (string, error) {
var builder strings.Builder
saveOutput, err := collectIPTablesSave()
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
rules, err := collectIPTablesRules(saveBin, listBin)
if err != nil {
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
} else {
builder.WriteString("=== iptables-save output ===\n")
log.Warnf("Failed to collect %s rules: %v", listBin, err)
return
}
if g.anonymize {
rules = g.anonymizer.AnonymizeString(rules)
}
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
}
}
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
// Returns an error when neither command produced any output (e.g. the binary is missing),
// so the caller can skip writing an empty file.
func collectIPTablesRules(saveBin, listBin string) (string, error) {
var builder strings.Builder
var collected bool
var firstErr error
saveOutput, err := runCommand(saveBin)
switch {
case err != nil:
firstErr = err
log.Warnf("Failed to collect %s output: %v", saveBin, err)
case strings.TrimSpace(saveOutput) == "":
log.Debugf("%s produced no output, skipping", saveBin)
default:
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
builder.WriteString(saveOutput)
builder.WriteString("\n")
collected = true
}
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n")
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
builder.WriteString(listHeader)
tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table))
stats, err := getTableStatistics(table)
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err)
if firstErr == nil {
firstErr = err
}
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
continue
}
builder.WriteString(fmt.Sprintf("*%s\n", table))
builder.WriteString(stats)
builder.WriteString("\n")
collected = true
}
if !collected {
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
}
return builder.String(), nil
}
@@ -214,34 +238,15 @@ func collectIPSets() (string, error) {
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save")
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
func runCommand(name string, args ...string) (string, error) {
cmd := exec.Command(name, args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
}
rules := stdout.String()
if strings.TrimSpace(rules) == "" {
return "", fmt.Errorf("no iptables rules found")
}
return rules, nil
}
// getTableStatistics gets verbose statistics for an entire table using iptables command
func getTableStatistics(table string) (string, error) {
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
}
return stdout.String(), nil
@@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
return fmt.Sprintf("type-%v", keyType)
}
}
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
func (g *BundleGenerator) addSysctls() error {
log.Info("Collecting sysctls")
content := collectSysctls()
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
return fmt.Errorf("add sysctls to bundle: %w", err)
}
return nil
}
// collectSysctls reads every sysctl that the netbird client may modify, plus
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
func collectSysctls() string {
var builder strings.Builder
writeSysctlGroup(&builder, "forwarding", []string{
"net.ipv4.ip_forward",
"net.ipv6.conf.all.forwarding",
"net.ipv6.conf.default.forwarding",
})
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
writeSysctlGroup(&builder, "rp_filter", append(
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
listInterfaceSysctls("ipv4", "rp_filter")...,
))
writeSysctlGroup(&builder, "src_valid_mark", append(
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
listInterfaceSysctls("ipv4", "src_valid_mark")...,
))
writeSysctlGroup(&builder, "conntrack", []string{
"net.netfilter.nf_conntrack_acct",
"net.netfilter.nf_conntrack_tcp_loose",
})
writeSysctlGroup(&builder, "tcp", []string{
"net.ipv4.tcp_tw_reuse",
})
return builder.String()
}
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
for _, key := range keys {
value, err := readSysctl(key)
if err != nil {
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
continue
}
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
}
builder.WriteString("\n")
}
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
// (callers add those explicitly so they appear first).
func listInterfaceSysctls(family, leaf string) []string {
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
var keys []string
for _, e := range entries {
name := e.Name()
if name == "all" || name == "default" {
continue
}
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
}
sort.Strings(keys)
return keys
}
func readSysctl(key string) (string, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
value, err := os.ReadFile(path)
if err != nil {
return "", err
}
return strings.TrimSpace(string(value)), nil
}

View File

@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
// IP rules are only supported on Linux
return nil
}
func (g *BundleGenerator) addSysctls() error {
// Sysctl collection is only supported on Linux
return nil
}

View File

@@ -0,0 +1,63 @@
package dnsfw
import (
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
const (
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
// EnvPorts overrides the comma-separated list of remote ports to block.
// Empty disables the firewall.
EnvPorts = "NB_DNS_FIREWALL_PORTS"
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
// and the netbird daemon. Default mode also permits anything on the
// netbird tunnel interface, which is safer if NRPT is silently ignored
// by Windows but lets apps reach custom DNS servers via the tunnel.
EnvStrict = "NB_DNS_FIREWALL_STRICT"
)
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
var defaultBlockedPorts = []uint16{53, 853}
// blockedPorts returns the effective port list, honoring env overrides.
// A nil return means the firewall should not be installed.
func blockedPorts() []uint16 {
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
log.Infof("dns firewall disabled via %s", EnvDisable)
return nil
}
override, ok := os.LookupEnv(EnvPorts)
if !ok {
return defaultBlockedPorts
}
var ports []uint16
for _, raw := range strings.Split(override, ",") {
raw = strings.TrimSpace(raw)
if raw == "" {
continue
}
port, err := strconv.ParseUint(raw, 10, 16)
if err != nil {
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
continue
}
if port == 0 {
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
continue
}
ports = append(ports, uint16(port))
}
if len(ports) == 0 {
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
return nil
}
return ports
}

View File

@@ -0,0 +1,39 @@
package dnsfw
import (
"reflect"
"testing"
)
func TestBlockedPorts(t *testing.T) {
tests := []struct {
name string
disable string
ports string
setPorts bool
want []uint16
}{
{name: "default", want: defaultBlockedPorts},
{name: "disabled", disable: "true", want: nil},
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
{name: "override empty disables", ports: "", setPorts: true, want: nil},
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvDisable, tc.disable)
if tc.setPorts {
t.Setenv(EnvPorts, tc.ports)
}
got := blockedPorts()
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
}
})
}
}

View File

@@ -0,0 +1,16 @@
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
// managing the host's DNS, so that resolvers running on apps or libraries
// outside netbird cannot bypass the configured DNS path.
//
// Implementation is Windows-only (uses WFP). On other platforms New returns
// a no-op manager.
package dnsfw
import "net/netip"
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
// to call multiple times.
type Manager interface {
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
Disable() error
}

View File

@@ -0,0 +1,15 @@
//go:build !windows
package dnsfw
import "net/netip"
type noopManager struct{}
func (noopManager) Enable(string, netip.Addr) error { return nil }
func (noopManager) Disable() error { return nil }
// New returns a no-op manager on non-Windows platforms.
func New() Manager {
return noopManager{}
}

View File

@@ -0,0 +1,144 @@
//go:build windows
package dnsfw
import (
"fmt"
"net/netip"
"os"
"strconv"
"sync"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
)
type windowsManager struct {
mu sync.Mutex
// session is the WFP engine handle. Zero when disabled.
session uintptr
}
// Enable installs the dns firewall. Strict mode propagates failures;
// non-strict mode logs and returns nil so partial protection is preserved.
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
m.mu.Lock()
defer m.mu.Unlock()
ports := blockedPorts()
if len(ports) == 0 {
return nil
}
if m.session != 0 {
if err := m.disableLocked(); err != nil {
return fmt.Errorf("reset existing dns firewall session: %w", err)
}
}
strict := strictMode()
luid, err := luidFromGUID(ifaceGUID)
if err != nil {
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
}
exe, err := os.Executable()
if err != nil {
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
}
cfg := installConfig{
tunLUID: luid,
daemonExe: exe,
blockedPorts: ports,
strict: strict,
virtualDNSIP: virtualDNSIP,
}
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
session, installErr := installFilters(cfg)
if session == 0 {
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
}
if installErr != nil && strict {
_ = closeSession(session)
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
}
m.session = session
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
ifaceGUID, exe, ports, strict, virtualDNSIP)
if installErr != nil {
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
}
return nil
}
func (m *windowsManager) Disable() error {
m.mu.Lock()
defer m.mu.Unlock()
return m.disableLocked()
}
func (m *windowsManager) disableLocked() error {
if m.session == 0 {
return nil
}
session := m.session
m.session = 0
if err := closeSession(session); err != nil {
return fmt.Errorf("close wfp session: %w", err)
}
log.Info("dns firewall removed")
return nil
}
// failOrLog returns err unchanged in strict mode. In non-strict mode the
// error is logged and nil is returned.
func (m *windowsManager) failOrLog(strict bool, err error) error {
if strict {
return err
}
log.Errorf("dns firewall: %v", err)
return nil
}
// New returns a Windows DNS firewall manager backed by WFP.
func New() Manager {
return &windowsManager{}
}
// strictMode reports whether strict mode is enabled via env.
func strictMode() bool {
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
return v
}
// luidFromGUID converts a Windows interface GUID string to its LUID.
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in luidFromGUID: %v", r)
}
}()
guid, err := windows.GUIDFromString(ifaceGUID)
if err != nil {
return 0, fmt.Errorf("parse guid: %w", err)
}
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
uintptr(unsafe.Pointer(&guid)),
uintptr(unsafe.Pointer(&luid)),
)
if rc != 0 {
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
}
return luid, nil
}

View File

@@ -0,0 +1,72 @@
//go:build windows
package dnsfw
import (
"net/netip"
"os"
"testing"
)
func TestStrictMode(t *testing.T) {
tests := []struct {
name string
val string
set bool
want bool
}{
{name: "unset", want: false},
{name: "true", val: "true", set: true, want: true},
{name: "1", val: "1", set: true, want: true},
{name: "false", val: "false", set: true, want: false},
{name: "invalid is false", val: "garbage", set: true, want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvStrict, tc.val)
if !tc.set {
os.Unsetenv(EnvStrict)
}
if got := strictMode(); got != tc.want {
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
}
})
}
}
func TestWindowsManagerDisableIdempotent(t *testing.T) {
m := &windowsManager{}
if err := m.Disable(); err != nil {
t.Fatalf("first Disable on fresh manager: %v", err)
}
if err := m.Disable(); err != nil {
t.Fatalf("second Disable on fresh manager: %v", err)
}
if m.session != 0 {
t.Fatalf("session should remain zero, got %d", m.session)
}
}
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
t.Setenv(EnvDisable, "true")
m := &windowsManager{}
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
}
if m.session != 0 {
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
}
}
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
t.Setenv(EnvPorts, "")
m := &windowsManager{}
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
}
if m.session != 0 {
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
}
}

View File

@@ -0,0 +1,53 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
*/
package dnsfw
import (
"errors"
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/windows"
)
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
namePtr, err := windows.UTF16PtrFromString(name)
if err != nil {
return nil, wrapErr(err)
}
descriptionPtr, err := windows.UTF16PtrFromString(description)
if err != nil {
return nil, wrapErr(err)
}
return &wtFwpmDisplayData0{
name: namePtr,
description: descriptionPtr,
}, nil
}
func filterWeight(weight uint8) wtFwpValue0 {
return wtFwpValue0{
_type: cFWP_UINT8,
value: uintptr(weight),
}
}
func wrapErr(err error) error {
var errno syscall.Errno
if !errors.As(err, &errno) {
return err
}
_, file, line, ok := runtime.Caller(1)
if !ok {
return fmt.Errorf("wfp error at unknown location: %w", err)
}
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
}

View File

@@ -0,0 +1,249 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
*
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
* uses for its kill-switch DNS leak protection. We extend it with a
* configurable port set so we also cover :853 (DoT) and any future ports.
*/
package dnsfw
import (
"encoding/binary"
"fmt"
"net/netip"
"unsafe"
"github.com/hashicorp/go-multierror"
"golang.org/x/sys/windows"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
// follow the authorized outbound flow.
// permitTunInterface installs a permit filter for any traffic whose local
// interface is the netbird tunnel.
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
cond := wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT64,
value: uintptr(unsafe.Pointer(&ifLUID)),
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: 1,
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
}
// permitDaemonByAppID installs a permit filter matching the netbird daemon
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
// dedicated binary.
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
appID, err := daemonAppID(daemonExe)
if err != nil {
return err
}
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
cond := wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_BYTE_BLOB_TYPE,
value: uintptr(unsafe.Pointer(appID)),
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: 1,
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
return addOutboundFilters(session, &filter, "Permit netbird daemon")
}
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
// permitTunInterface.
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
var merr *multierror.Error
for _, port := range ports {
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
if !ip.IsValid() {
return fmt.Errorf("invalid address")
}
var addrCond wtFwpmFilterCondition0
var layer windows.GUID
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
var v6 wtFwpByteArray16
if ip.Is4() {
v4 := ip.As4()
addrCond = wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT32,
value: uintptr(binary.BigEndian.Uint32(v4[:])),
},
}
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
} else {
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
addrCond = wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_BYTE_ARRAY16_TYPE,
value: uintptr(unsafe.Pointer(&v6)),
},
}
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
}
conditions := [2]wtFwpmFilterCondition0{
addrCond,
{
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT16,
value: uintptr(port),
},
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: uint32(len(conditions)),
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
}
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
if err != nil {
return wrapErr(err)
}
filter.displayData = *display
filter.layerKey = layer
var filterID uint64
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
return wrapErr(err)
}
_ = v6
return nil
}
// blockDNSPorts installs a deny filter for outbound traffic to each of the
// given remote ports over UDP or TCP. Per-port and per-layer failures are
// accumulated; partial coverage is preferred over zero coverage.
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
var merr *multierror.Error
for _, port := range ports {
if err := blockDNSPort(session, base, port, weight); err != nil {
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
conditions := [3]wtFwpmFilterCondition0{
{
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT16,
value: uintptr(port),
},
},
{
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT8,
value: uintptr(cIPPROTO_UDP),
},
},
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
{
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT8,
value: uintptr(cIPPROTO_TCP),
},
},
}
filter := wtFwpmFilter0{
providerKey: &base.provider,
subLayerKey: base.filters,
weight: filterWeight(weight),
numFilterConditions: uint32(len(conditions)),
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
}
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
}
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
// connect layers. v4 and v6 are installed independently: failure on one
// layer does not abort the other, and the accumulated errors are returned.
// Partial coverage is preferred over zero coverage.
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
layers := [...]struct {
layer windows.GUID
label string
}{
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
}
var merr *multierror.Error
for _, l := range layers {
display, err := createWtFwpmDisplayData0(l.label, "")
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
continue
}
filter.displayData = *display
filter.layerKey = l.layer
var filterID uint64
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,177 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
*
* Session lifecycle and the high-level Install/Close entry points adapted
* from wireguard-windows tunnel/firewall.
*/
package dnsfw
import (
"errors"
"fmt"
"net/netip"
"unsafe"
"github.com/hashicorp/go-multierror"
"golang.org/x/sys/windows"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// installConfig is the input to installFilters.
type installConfig struct {
tunLUID uint64
daemonExe string
blockedPorts []uint16
// strict, when true, narrows the carve-out from "anything on tun" to
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
strict bool
virtualDNSIP netip.Addr
}
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
// for our session. Both are randomly generated per session.
type baseObjects struct {
provider windows.GUID
filters windows.GUID
}
// installFilters opens a dynamic WFP session and installs the netbird DNS
// firewall filters. Returns a zero session on hard failure (session create,
// base objects); a non-zero session with a non-nil error is a partial install
// (some per-filter installs failed) and is safe to close.
func installFilters(cfg installConfig) (session uintptr, err error) {
defer func() {
if r := recover(); r != nil {
// Dynamic session: kernel will clean up on process exit even
// if we leave the handle dangling here.
err = fmt.Errorf("panic in installFilters: %v", r)
}
}()
if len(cfg.blockedPorts) == 0 {
return 0, errors.New("dns firewall: no blocked ports configured")
}
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
}
session, err = createSession()
if err != nil {
return 0, err
}
base, err := registerBaseObjects(session)
if err != nil {
_ = fwpmEngineClose0(session)
return 0, fmt.Errorf("register base objects: %w", err)
}
var merr *multierror.Error
if cfg.strict {
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
}
} else {
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
}
}
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
}
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
}
return session, nberrors.FormatErrorOrNil(merr)
}
// closeSession tears down a WFP session previously opened by installFilters.
// All filters owned by the session are removed.
func closeSession(session uintptr) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in closeSession: %v", r)
}
}()
if session == 0 {
return nil
}
if err := fwpmEngineClose0(session); err != nil {
return wrapErr(err)
}
return nil
}
func createSession() (uintptr, error) {
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
if err != nil {
return 0, wrapErr(err)
}
session := wtFwpmSession0{
displayData: *displayData,
flags: cFWPM_SESSION_FLAG_DYNAMIC,
txnWaitTimeoutInMSec: windows.INFINITE,
}
var handle uintptr
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
return 0, wrapErr(err)
}
return handle, nil
}
func registerBaseObjects(session uintptr) (*baseObjects, error) {
bo := &baseObjects{}
var err error
if bo.provider, err = windows.GenerateGUID(); err != nil {
return nil, wrapErr(err)
}
if bo.filters, err = windows.GenerateGUID(); err != nil {
return nil, wrapErr(err)
}
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
if err != nil {
return nil, wrapErr(err)
}
provider := wtFwpmProvider0{
providerKey: bo.provider,
displayData: *displayData,
}
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
return nil, wrapErr(err)
}
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
if err != nil {
return nil, wrapErr(err)
}
sublayer := wtFwpmSublayer0{
subLayerKey: bo.filters,
displayData: *subDisplay,
providerKey: &bo.provider,
weight: ^uint16(0),
}
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
return nil, wrapErr(err)
}
return bo, nil
}
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
func daemonAppID(path string) (*wtFwpByteBlob, error) {
pathPtr, err := windows.UTF16PtrFromString(path)
if err != nil {
return nil, wrapErr(err)
}
var appID *wtFwpByteBlob
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
return nil, wrapErr(err)
}
return appID, nil
}

View File

@@ -0,0 +1,38 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
*/
package dnsfw
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0

View File

@@ -0,0 +1,414 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
*/
package dnsfw
import "golang.org/x/sys/windows"
const (
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
wtFwpBitmapArray64_Size = 8
wtFwpByteArray16_Size = 16
wtFwpByteArray6_Size = 6
wtFwpmAction0_Size = 20
wtFwpmAction0_filterType_Offset = 4
wtFwpV4AddrAndMask_Size = 8
wtFwpV4AddrAndMask_mask_Offset = 4
wtFwpV6AddrAndMask_Size = 17
wtFwpV6AddrAndMask_prefixLength_Offset = 16
)
type wtFwpActionFlag uint32
const (
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
)
// FWP_ACTION_TYPE defined in fwptypes.h
type wtFwpActionType uint32
const (
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
)
// FWP_BYTE_BLOB defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
type wtFwpByteBlob struct {
size uint32
data *uint8
}
// FWP_MATCH_TYPE defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
type wtFwpMatchType uint32
const (
cFWP_MATCH_EQUAL wtFwpMatchType = 0
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
)
// FWPM_ACTION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
type wtFwpmAction0 struct {
_type wtFwpActionType
filterType windows.GUID // Windows type: GUID
}
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
Data1: 0x4cd62a49,
Data2: 0x59c3,
Data3: 0x4969,
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
}
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
Data1: 0xb235ae9a,
Data2: 0x1d64,
Data3: 0x49b8,
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
}
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
Data1: 0x3971ef2b,
Data2: 0x623e,
Data3: 0x4f9a,
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
}
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
Data1: 0x0c1ba1af,
Data2: 0x5765,
Data3: 0x453f,
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
}
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
Data1: 0xc35a604d,
Data2: 0xd22b,
Data3: 0x4e1a,
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
}
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
Data1: 0xd78e1e87,
Data2: 0x8644,
Data3: 0x4ea5,
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
}
// af043a0a-b34d-4f86-979c-c90371af6e66
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
Data1: 0xaf043a0a,
Data2: 0xb34d,
Data3: 0x4f86,
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
}
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
Data1: 0xd9ee00de,
Data2: 0xc1ef,
Data3: 0x4617,
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
}
var (
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
)
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
Data1: 0x7bc43cbf,
Data2: 0x37ba,
Data3: 0x45f1,
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
}
type wtFwpmL2Flags uint32
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
var cFWPM_CONDITION_FLAGS = windows.GUID{
Data1: 0x632ce23b,
Data2: 0x5167,
Data3: 0x435c,
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
}
type wtFwpmFlags uint32
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
// Defined in fwpmtypes.h
type wtFwpmFilterFlags uint32
const (
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
)
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
Data1: 0xc38d57d1,
Data2: 0x05a7,
Data3: 0x4c33,
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
}
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
Data1: 0xe1cd9fe7,
Data2: 0xf4b5,
Data3: 0x4273,
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
}
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
Data1: 0x4a72393b,
Data2: 0x319f,
Data3: 0x44bc,
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
}
// a3b42c97-9f04-4672-b87e-cee9c483257f
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
Data1: 0xa3b42c97,
Data2: 0x9f04,
Data3: 0x4672,
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
}
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
Data1: 0x94c44912,
Data2: 0x9d6f,
Data3: 0x4ebf,
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
}
// d4220bd3-62ce-4f08-ae88-b56e8526df50
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
Data1: 0xd4220bd3,
Data2: 0x62ce,
Data3: 0x4f08,
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
}
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
type wtFwpBitmapArray64 struct {
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
}
// FWP_BYTE_ARRAY6 defined in fwtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
type wtFwpByteArray6 struct {
byteArray6 [6]uint8 // Windows type: [6]UINT8
}
// FWP_BYTE_ARRAY16 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
type wtFwpByteArray16 struct {
byteArray16 [16]uint8 // Windows type [16]UINT8
}
// FWP_CONDITION_VALUE0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
type wtFwpConditionValue0 wtFwpValue0
// FWP_DATA_TYPE defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
type wtFwpDataType uint
const (
cFWP_EMPTY wtFwpDataType = 0
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
)
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
type wtFwpV4AddrAndMask struct {
addr uint32
mask uint32
}
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
type wtFwpV6AddrAndMask struct {
addr [16]uint8
prefixLength uint8
}
// FWP_VALUE0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
type wtFwpValue0 struct {
_type wtFwpDataType
value uintptr
}
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
type wtFwpmDisplayData0 struct {
name *uint16 // Windows type: *wchar_t
description *uint16 // Windows type: *wchar_t
}
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
type wtFwpmFilterCondition0 struct {
fieldKey windows.GUID // Windows type: GUID
matchType wtFwpMatchType
conditionValue wtFwpConditionValue0
}
// FWPM_PROVIDER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
type wtFwpProvider0 struct {
providerKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags uint32
providerData wtFwpByteBlob
serviceName *uint16 // Windows type: *wchar_t
}
type wtFwpmSessionFlagsValue uint32
const (
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
)
// FWPM_SESSION0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
type wtFwpmSession0 struct {
sessionKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmSessionFlagsValue // Windows type UINT32
txnWaitTimeoutInMSec uint32
processId uint32 // Windows type: DWORD
sid *windows.SID
username *uint16 // Windows type: *wchar_t
kernelMode uint8 // Windows type: BOOL
}
type wtFwpmSublayerFlags uint32
const (
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
)
// FWPM_SUBLAYER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
type wtFwpmSublayer0 struct {
subLayerKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmSublayerFlags
providerKey *windows.GUID // Windows type: *GUID
providerData wtFwpByteBlob
weight uint16
}
// Defined in rpcdce.h
type wtRpcCAuthN uint32
const (
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
)
// FWPM_PROVIDER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
type wtFwpmProvider0 struct {
providerKey windows.GUID
displayData wtFwpmDisplayData0
flags uint32
providerData wtFwpByteBlob
serviceName *uint16
}
type wtIPProto uint32
const (
cIPPROTO_ICMP wtIPProto = 1
cIPPROTO_ICMPV6 wtIPProto = 58
cIPPROTO_TCP wtIPProto = 6
cIPPROTO_UDP wtIPProto = 17
)
const (
cFWP_ACTRL_MATCH_FILTER = 1
)

View File

@@ -0,0 +1,92 @@
//go:build windows && (386 || arm)
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/types_windows_32.go.
*/
package dnsfw
import "golang.org/x/sys/windows"
const (
wtFwpByteBlob_Size = 8
wtFwpByteBlob_data_Offset = 4
wtFwpConditionValue0_Size = 8
wtFwpConditionValue0_uint8_Offset = 4
wtFwpmDisplayData0_Size = 8
wtFwpmDisplayData0_description_Offset = 4
wtFwpmFilter0_Size = 152
wtFwpmFilter0_displayData_Offset = 16
wtFwpmFilter0_flags_Offset = 24
wtFwpmFilter0_providerKey_Offset = 28
wtFwpmFilter0_providerData_Offset = 32
wtFwpmFilter0_layerKey_Offset = 40
wtFwpmFilter0_subLayerKey_Offset = 56
wtFwpmFilter0_weight_Offset = 72
wtFwpmFilter0_numFilterConditions_Offset = 80
wtFwpmFilter0_filterCondition_Offset = 84
wtFwpmFilter0_action_Offset = 88
wtFwpmFilter0_providerContextKey_Offset = 112
wtFwpmFilter0_reserved_Offset = 128
wtFwpmFilter0_filterID_Offset = 136
wtFwpmFilter0_effectiveWeight_Offset = 144
wtFwpmFilterCondition0_Size = 28
wtFwpmFilterCondition0_matchType_Offset = 16
wtFwpmFilterCondition0_conditionValue_Offset = 20
wtFwpmSession0_Size = 48
wtFwpmSession0_displayData_Offset = 16
wtFwpmSession0_flags_Offset = 24
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 28
wtFwpmSession0_processId_Offset = 32
wtFwpmSession0_sid_Offset = 36
wtFwpmSession0_username_Offset = 40
wtFwpmSession0_kernelMode_Offset = 44
wtFwpmSublayer0_Size = 44
wtFwpmSublayer0_displayData_Offset = 16
wtFwpmSublayer0_flags_Offset = 24
wtFwpmSublayer0_providerKey_Offset = 28
wtFwpmSublayer0_providerData_Offset = 32
wtFwpmSublayer0_weight_Offset = 40
wtFwpProvider0_Size = 40
wtFwpProvider0_displayData_Offset = 16
wtFwpProvider0_flags_Offset = 24
wtFwpProvider0_providerData_Offset = 28
wtFwpProvider0_serviceName_Offset = 36
wtFwpTokenInformation_Size = 16
wtFwpValue0_Size = 8
wtFwpValue0_value_Offset = 4
)
// FWPM_FILTER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
type wtFwpmFilter0 struct {
filterKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmFilterFlags
providerKey *windows.GUID // Windows type: *GUID
providerData wtFwpByteBlob
layerKey windows.GUID // Windows type: GUID
subLayerKey windows.GUID // Windows type: GUID
weight wtFwpValue0
numFilterConditions uint32
filterCondition *wtFwpmFilterCondition0
action wtFwpmAction0
offset1 [4]byte // Layout correction field
providerContextKey windows.GUID // Windows type: GUID
reserved *windows.GUID // Windows type: *GUID
offset2 [4]byte // Layout correction field
filterID uint64
effectiveWeight wtFwpValue0
}

View File

@@ -0,0 +1,89 @@
//go:build windows && (amd64 || arm64)
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*
* Adapted from wireguard-windows tunnel/firewall/types_windows_64.go.
*/
package dnsfw
import "golang.org/x/sys/windows"
const (
wtFwpByteBlob_Size = 16
wtFwpByteBlob_data_Offset = 8
wtFwpConditionValue0_Size = 16
wtFwpConditionValue0_uint8_Offset = 8
wtFwpmDisplayData0_Size = 16
wtFwpmDisplayData0_description_Offset = 8
wtFwpmFilter0_Size = 200
wtFwpmFilter0_displayData_Offset = 16
wtFwpmFilter0_flags_Offset = 32
wtFwpmFilter0_providerKey_Offset = 40
wtFwpmFilter0_providerData_Offset = 48
wtFwpmFilter0_layerKey_Offset = 64
wtFwpmFilter0_subLayerKey_Offset = 80
wtFwpmFilter0_weight_Offset = 96
wtFwpmFilter0_numFilterConditions_Offset = 112
wtFwpmFilter0_filterCondition_Offset = 120
wtFwpmFilter0_action_Offset = 128
wtFwpmFilter0_providerContextKey_Offset = 152
wtFwpmFilter0_reserved_Offset = 168
wtFwpmFilter0_filterID_Offset = 176
wtFwpmFilter0_effectiveWeight_Offset = 184
wtFwpmFilterCondition0_Size = 40
wtFwpmFilterCondition0_matchType_Offset = 16
wtFwpmFilterCondition0_conditionValue_Offset = 24
wtFwpmSession0_Size = 72
wtFwpmSession0_displayData_Offset = 16
wtFwpmSession0_flags_Offset = 32
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 36
wtFwpmSession0_processId_Offset = 40
wtFwpmSession0_sid_Offset = 48
wtFwpmSession0_username_Offset = 56
wtFwpmSession0_kernelMode_Offset = 64
wtFwpmSublayer0_Size = 72
wtFwpmSublayer0_displayData_Offset = 16
wtFwpmSublayer0_flags_Offset = 32
wtFwpmSublayer0_providerKey_Offset = 40
wtFwpmSublayer0_providerData_Offset = 48
wtFwpmSublayer0_weight_Offset = 64
wtFwpProvider0_Size = 64
wtFwpProvider0_displayData_Offset = 16
wtFwpProvider0_flags_Offset = 32
wtFwpProvider0_providerData_Offset = 40
wtFwpProvider0_serviceName_Offset = 56
wtFwpValue0_Size = 16
wtFwpValue0_value_Offset = 8
)
// FWPM_FILTER0 defined in fwpmtypes.h
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
type wtFwpmFilter0 struct {
filterKey windows.GUID // Windows type: GUID
displayData wtFwpmDisplayData0
flags wtFwpmFilterFlags // Windows type: UINT32
providerKey *windows.GUID // Windows type: *GUID
providerData wtFwpByteBlob
layerKey windows.GUID // Windows type: GUID
subLayerKey windows.GUID // Windows type: GUID
weight wtFwpValue0
numFilterConditions uint32
filterCondition *wtFwpmFilterCondition0
action wtFwpmAction0
offset1 [4]byte // Layout correction field
providerContextKey windows.GUID // Windows type: GUID
reserved *windows.GUID // Windows type: *GUID
filterID uint64
effectiveWeight wtFwpValue0
}

View File

@@ -0,0 +1,130 @@
// Code generated by 'go generate'; DO NOT EDIT.
package dnsfw
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
)
func fwpmEngineClose0(engineHandle uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmFreeMemory0(p unsafe.Pointer) {
syscall.Syscall(procFwpmFreeMemory0.Addr(), 1, uintptr(p), 0, 0)
return
}
func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}
func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
err = errnoErr(e1)
}
return
}

View File

@@ -16,6 +16,10 @@ type hostManager interface {
restoreHostDNS() error
supportCustomPort() bool
string() string
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
// hardcoded Quad9 on iOS, nil for noop / mock.
getOriginalNameservers() []netip.Addr
}
type SystemDNSSettings struct {
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
func (n noopHostConfigurator) string() string {
return "noop"
}
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}

View File

@@ -1,14 +1,20 @@
package dns
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// androidHostManager is a noop on the OS side (Android's VPN service handles
// DNS for us) but tracks the OS-reported resolver list pushed via
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
type androidHostManager struct {
holder *hostsDNSHolder
}
func newHostManager() (*androidHostManager, error) {
return &androidHostManager{}, nil
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
return &androidHostManager{holder: holder}, nil
}
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
func (a androidHostManager) string() string {
return "none"
}
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
hosts := a.holder.get()
out := make([]netip.Addr, 0, len(hosts))
for ap := range hosts {
out = append(out, ap.Addr())
}
return out
}

View File

@@ -3,6 +3,7 @@ package dns
import (
"encoding/json"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
}, nil
}
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
return []netip.Addr{
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
}
}
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
jsonData, err := json.Marshal(config)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strings"
"syscall"
"time"
@@ -16,6 +17,7 @@ import (
"golang.org/x/sys/windows/registry"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/winregistry"
)
@@ -44,9 +46,11 @@ const (
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
@@ -67,10 +71,12 @@ const (
)
type registryConfigurator struct {
guid string
routingAll bool
gpo bool
nrptEntryCount int
guid string
routingAll bool
gpo bool
nrptEntryCount int
dnsFirewall dnsfw.Manager
origNameservers []netip.Addr
}
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -90,10 +96,22 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
}
configurator := &registryConfigurator{
guid: guid,
gpo: useGPO,
guid: guid,
gpo: useGPO,
dnsFirewall: dnsfw.New(),
}
origNameservers, err := configurator.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
}
configurator.origNameservers = origNameservers
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
@@ -101,6 +119,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
return configurator, nil
}
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
// registry key except the WG adapter. v4 and v6 servers live in separate
// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
var merr *multierror.Error
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
addrs, err := r.captureFromTcpipRoot(root)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
continue
}
for _, addr := range addrs {
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
if err != nil {
return nil, fmt.Errorf("open key: %w", err)
}
defer closer(root)
guids, err := root.ReadSubKeyNames(-1)
if err != nil {
return nil, fmt.Errorf("read subkeys: %w", err)
}
var out []netip.Addr
for _, guid := range guids {
if strings.EqualFold(guid, r.guid) {
continue
}
out = append(out, readInterfaceNameservers(rootPath, guid)...)
}
return out, nil
}
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
keyPath := rootPath + "\\" + guid
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
if err != nil {
return nil
}
defer closer(k)
// Static NameServer wins over DhcpNameServer for actual resolution.
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
raw, _, err := k.GetStringValue(name)
if err != nil || raw == "" {
continue
}
if out := parseRegistryNameservers(raw); len(out) > 0 {
return out
}
}
return nil
}
func parseRegistryNameservers(raw string) []netip.Addr {
var out []netip.Addr
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
addr, err := netip.ParseAddr(strings.TrimSpace(field))
if err != nil {
continue
}
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// Drop unzoned link-local: not routable without a scope id. If
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
continue
}
out = append(out, addr)
}
return out
}
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(r.origNameservers)
}
func (r *registryConfigurator) supportCustomPort() bool {
return false
}
@@ -169,16 +279,8 @@ func (r *registryConfigurator) disableWINSForInterface() error {
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if config.RouteAll {
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
} else if r.routingAll {
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
return fmt.Errorf("delete interface registry key property: %w", err)
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
if err := r.applyRouteAll(config); err != nil {
return err
}
r.updateState(stateManager)
@@ -220,6 +322,35 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil
}
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
if config.RouteAll {
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
return fmt.Errorf("dns firewall: %w", err)
}
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
merr := multierror.Append(nil, fmt.Errorf("add dns setup: %w", err))
if dErr := r.dnsFirewall.Disable(); dErr != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback dns firewall: %w", dErr))
}
return nberrors.FormatErrorOrNil(merr)
}
return nil
}
if err := r.dnsFirewall.Disable(); err != nil {
log.Errorf("disable dns firewall: %v", err)
}
if !r.routingAll {
return nil
}
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
return fmt.Errorf("delete interface registry key property: %w", err)
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
return nil
}
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
@@ -406,6 +537,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err)
}
if err := r.dnsFirewall.Disable(); err != nil {
log.Errorf("disable dns firewall: %v", err)
}
go r.flushDNSCache()
return nil

View File

@@ -8,6 +8,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
)
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
@@ -34,8 +36,9 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
guid: testGUID,
gpo: false,
dnsFirewall: dnsfw.New(),
}
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
@@ -134,8 +137,9 @@ func TestNRPTDomainBatching(t *testing.T) {
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
guid: testGUID,
gpo: false,
dnsFirewall: dnsfw.New(),
}
testCases := []struct {

View File

@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Unlock()
}
//nolint:unused
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList

View File

@@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability(context.Context) {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{

View File

@@ -9,6 +9,7 @@ import (
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
return make([]string, 0)
}
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() {
}
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// SetRouteSources mock implementation of SetRouteSources from Server interface
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
// Mock implementation - no-op
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"time"
@@ -32,6 +33,15 @@ const (
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
networkManagerDbusIPv4Key = "ipv4"
networkManagerDbusIPv6Key = "ipv6"
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
}
type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
origNameservers []netip.Addr
}
// the types below are based on dbus specification, each field is mapped to a dbus type
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
return &networkManagerDbusConfigurator{
c := &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from NetworkManager: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads DNS servers from every NM device's
// IP4Config / IP6Config except our WG device.
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
devices, err := networkManagerListDevices()
if err != nil {
return nil, fmt.Errorf("list devices: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, dev := range devices {
if dev == n.dbusLinkObject {
continue
}
ifaceName := readNetworkManagerDeviceInterface(dev)
for _, addr := range readNetworkManagerDeviceDNS(dev) {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// IP6Config.Nameservers is a byte slice without zone info;
// reattach the device's interface name so a captured fe80::…
// stays routable.
if addr.IsLinkLocalUnicast() && ifaceName != "" {
addr = addr.WithZone(ifaceName)
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return ""
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
if err != nil {
return ""
}
s, _ := v.Value().(string)
return s
}
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
}
defer closeConn()
var devs []dbus.ObjectPath
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
return nil, err
}
return devs, nil
}
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return nil
}
defer closeConn()
var out []netip.Addr
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
out = append(out, readIPv4ConfigDNS(path)...)
}
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
out = append(out, readIPv6ConfigDNS(path)...)
}
return out
}
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
v, err := obj.GetProperty(property)
if err != nil {
return ""
}
path, ok := v.Value().(dbus.ObjectPath)
if !ok || path == "/" {
return ""
}
return path
}
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
// legacy uint32 Nameservers property.
if out := readIPv4NameserverData(obj); len(out) > 0 {
return out
}
return readIPv4LegacyNameservers(obj)
}
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([]map[string]dbus.Variant)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
addrVar, ok := entry["address"]
if !ok {
continue
}
s, ok := addrVar.Value().(string)
if !ok {
continue
}
if a, err := netip.ParseAddr(s); err == nil {
out = append(out, a)
}
}
return out
}
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([]uint32)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, n := range raw {
var b [4]byte
binary.LittleEndian.PutUint32(b[:], n)
out = append(out, netip.AddrFrom4(b))
}
return out
}
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([][]byte)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, b := range raw {
if a, ok := netip.AddrFromSlice(b); ok {
out = append(out, a)
}
}
return out
}
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(n.origNameservers)
}
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager()
return newHostManager(s.hostsDNSHolder)
}

View File

@@ -6,7 +6,7 @@ import (
"net"
"net/netip"
"os"
"strings"
"runtime"
"testing"
"time"
@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -31,8 +32,10 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -101,16 +104,17 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
return &upstreamResolverBase{
domain: domain,
upstreamServers: srvs,
cancel: func() {},
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
}
}
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
{false, "domain0", false},
{false, "domain1", false},
{false, "domain2", false},
},
},
statusRecorder: peer.NewRecorder("mgm"),
}
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{}
for _, item := range config.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
domainsUpdate = strings.Join(domains, ",")
return nil
}
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"},
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
},
}, nil, 0)
deactivate(nil)
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got := strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got)
}
reactivate()
expected = "domain0,domain1,domain2"
domains = []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got = strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
}
}
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
func TestDNSPermanent_matchOnly(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
}
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
// + androidHostManager). On non-android the desktop host manager replaces it
// during Initialize and the assertion stops making sense. Skipped here until we
// have an android CI runner.
func skipUnlessAndroid(t *testing.T) {
t.Helper()
if runtime.GOOS != "android" {
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
t.Helper()
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
@@ -1065,7 +1017,6 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
// admin-defined nameserver groups targeting the same domain collapse into a
// single handler with each group preserved as a sequential inner list.
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
}
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
require.NoError(t, err)
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
assert.Equal(t, "example.com", muxUpdates[0].domain)
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
handler := muxUpdates[0].handler.(*upstreamResolver)
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
assert.Equal(t, upstreamRace{
netip.MustParseAddrPort("192.0.2.2:53"),
netip.MustParseAddrPort("192.0.2.3:53"),
}, handler.upstreamServers[1])
}
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
// (overlay route selected-but-no-active-peer) is intentionally NOT an
// input to the evaluator anymore: the verdict drives the Enabled flag,
// which must always reflect what we actually observed. Gate-aware event
// suppression is tested separately in the projection test.
//
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
// fresh-working → Unhealthy; otherwise Undecided.
func TestEvaluateNSGroupHealth(t *testing.T) {
now := time.Now()
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
okThenFail := UpstreamHealth{
LastOk: now.Add(-10 * time.Second),
LastFail: now.Add(-1 * time.Second),
LastErr: "timeout",
}
failThenOk := UpstreamHealth{
LastOk: now.Add(-1 * time.Second),
LastFail: now.Add(-10 * time.Second),
LastErr: "timeout",
}
tests := []struct {
name string
health map[netip.AddrPort]UpstreamHealth
servers []netip.AddrPort
wantVerdict nsGroupVerdict
wantErrSubst string
}{
{
name: "no record, undecided",
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "fresh success, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "fresh failure, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "only stale success, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "only stale failure, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "both fresh, fail newer, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "both fresh, ok newer, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one success wins",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
b: recentOk,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one fail one unseen, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "two upstreams, all recent failures, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "SERVFAIL",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
if tc.wantErrSubst != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantErrSubst)
} else {
assert.NoError(t, err)
}
})
}
}
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
health map[netip.AddrPort]UpstreamHealth
}
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (h *healthStubHandler) Stop() {}
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
return h.health
}
// TestProjection_SteadyStateIsSilent guards against duplicate events:
// while a group stays Unhealthy tick after tick, only the first
// Unhealthy transition may emit. Same for staying Healthy.
func TestProjection_SteadyStateIsSilent(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "first fail emits warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.tick()
fx.expectNoEvent("staying unhealthy must not re-emit")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery on transition")
fx.tick()
fx.tick()
fx.expectNoEvent("staying healthy must not re-emit")
}
// projTestFixture is the common setup for the projection tests: a
// single-upstream group whose route classification the test can flip by
// assigning to selected/active. Callers drive failures/successes by
// mutating stub.health and calling refreshHealth.
type projTestFixture struct {
t *testing.T
recorder *peer.Status
events <-chan *proto.SystemEvent
server *DefaultServer
stub *healthStubHandler
group *nbdns.NameServerGroup
srv netip.AddrPort
selected route.HAMap
active route.HAMap
}
func newProjTestFixture(t *testing.T) *projTestFixture {
t.Helper()
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
srv := netip.MustParseAddrPort("100.64.0.1:53")
fx := &projTestFixture{
t: t,
recorder: recorder,
events: sub.Events(),
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
srv: srv,
group: &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
},
}
fx.server = &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
fx.server.mux.Unlock()
return fx
}
func (f *projTestFixture) setHealth(h UpstreamHealth) {
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
}
func (f *projTestFixture) tick() []peer.NSGroupState {
f.server.refreshHealth()
return f.recorder.GetDNSStates()
}
func (f *projTestFixture) expectNoEvent(why string) {
f.t.Helper()
select {
case evt := <-f.events:
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
case <-time.After(100 * time.Millisecond):
}
}
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
f.t.Helper()
select {
case evt := <-f.events:
assert.Contains(f.t, evt.Message, substr, why)
return evt
case <-time.After(time.Second):
f.t.Fatalf("expected event (%s) with %q", why, substr)
return nil
}
}
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
// that is not inside any selected route (public DNS) fires the warning
// on the first Unhealthy tick, no grace period.
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "public DNS failure")
}
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
// the upstream is inside a selected route AND the route has a Connected
// peer. Tunnel is up, failure is real, emit immediately.
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "overlay + connected failure")
}
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
// First tick: Unhealthy display, no warning. After the grace window
// elapses with no recovery, the warning fires.
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
grace := 50 * time.Millisecond
fx := newProjTestFixture(t)
fx.server.warningDelayBase = grace
fx.selected = overlayMapForTest
// active stays nil: routed but not connected.
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
fx.expectNoEvent("first fail tick within grace window")
time.Sleep(grace + 10*time.Millisecond)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "warning after grace window")
}
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
// whose address is inside the WireGuard overlay range but is not
// covered by any selected route (peer-to-peer DNS without an explicit
// route). Until a peer reports Connected for that address, startup
// failures must be held just like the routed case.
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-sub.Events():
t.Fatalf("unexpected event during grace window: %+v", evt)
case <-time.After(100 * time.Millisecond):
}
time.Sleep(60 * time.Millisecond)
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
server.refreshHealth()
select {
case evt := <-sub.Events():
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected warning after grace window")
}
}
// TestProjection_StopClearsHealthState verifies that Stop wipes the
// per-group projection state so a subsequent Start doesn't inherit
// sticky flags (notably everHealthy) that would bypass the grace
// window during the next peer handshake.
func TestProjection_StopClearsHealthState(t *testing.T) {
wgIface := &mocWGIface{}
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgIface,
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: defaultWarningDelayBase,
currentConfigHash: ^uint64(0),
}
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
srv := netip.MustParseAddrPort("8.8.8.8:53")
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
server.healthProjectMu.Lock()
p, ok := server.nsGroupProj[generateGroupKey(group)]
server.healthProjectMu.Unlock()
require.True(t, ok, "projection state should exist after tick")
require.True(t, p.everHealthy, "tick with success must set everHealthy")
server.Stop()
server.healthProjectMu.Lock()
cleared := server.nsGroupProj == nil
server.healthProjectMu.Unlock()
assert.True(t, cleared, "Stop must clear nsGroupProj")
}
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
fx.selected = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectNoEvent("fail within grace, warning suppressed")
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("recovery without prior warning must not emit")
}
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
// whole design leans on: recovery events only appear when a warning
// event was actually emitted for the current streak. A Healthy verdict
// without a prior warning is silent, so the user never sees "recovered"
// out of thin air.
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("first healthy tick should not recover anything")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "public fail emits immediately")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery follows real warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "second cycle warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "second cycle recovery")
}
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
// has ever been Healthy, subsequent failures skip the grace window even
// if classification says "routed + not connected". The system has
// proved it can work, so any new failure is real.
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
fx := newProjTestFixture(t)
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
fx.server.warningDelayBase = time.Hour
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
// Establish "ever healthy".
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectNoEvent("first healthy tick")
// Peer drops. Query fails. Routed + not connected → normally grace,
// but everHealthy flag bypasses it.
fx.active = nil
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
}
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
// from the design discussion: once a group has been healthy, a brief
// reconnect that produces a failing tick will fire warning + recovery.
// This is by design: user-visible blips are accurate signal, not noise.
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "blip warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "blip recovery")
}
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
// rule: a group with at least one public upstream is in the "immediate"
// category regardless of the other upstreams' routing, because the
// public one has no peer-startup excuse. Prevents public-DNS failures
// from being hidden behind a routed sibling.
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
events := sub.Events()
public := netip.MustParseAddrPort("8.8.8.8:53")
overlay := netip.MustParseAddrPort("100.64.0.1:53")
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
},
}
stub := &healthStubHandler{
health: map[netip.AddrPort]UpstreamHealth{
public: {LastFail: time.Now(), LastErr: "servfail"},
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-events:
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected immediate warning because group contains a public upstream")
}
}
func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
flat := handler.flatUpstreams()
assert.Len(t, flat, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
if upstream.Addr() == expected {
found = true
break

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
"github.com/godbus/dbus/v5"
@@ -40,10 +41,17 @@ const (
)
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
ifaceName string
dbusLinkObject dbus.ObjectPath
ifaceName string
wgIndex int
origNameservers []netip.Addr
}
const (
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
)
// the types below are based on dbus specification, each field is mapped to a dbus type
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
return &systemdDbusConfigurator{
c := &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
wgIndex: iface.Index,
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
// every default-route link except our own WG link. Non-default-route links
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
// actually serve host queries.
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("list interfaces: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, iface := range ifaces {
if !s.isCandidateLink(iface) {
continue
}
linkPath, err := getSystemdLinkPath(iface.Index)
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
continue
}
for _, addr := range readSystemdLinkDNS(linkPath) {
addr = normalizeSystemdAddr(addr, iface.Name)
if !addr.IsValid() {
continue
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
if iface.Index == s.wgIndex {
return false
}
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
return false
}
return true
}
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
// Returns the zero Addr to signal "skip this entry".
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
return netip.Addr{}
}
if addr.IsLinkLocalUnicast() {
return addr.WithZone(ifaceName)
}
return addr
}
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return "", fmt.Errorf("dbus resolve1: %w", err)
}
defer closeConn()
var p string
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
return "", err
}
return dbus.ObjectPath(p), nil
}
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return false
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
if err != nil {
return false
}
b, ok := v.Value().(bool)
return ok && b
}
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([][]any)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
if len(entry) < 2 {
continue
}
raw, ok := entry[1].([]byte)
if !ok {
continue
}
addr, ok := netip.AddrFromSlice(raw)
if !ok {
continue
}
out = append(out, addr)
}
return out
}
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(s.origNameservers)
}
func (s *systemdDbusConfigurator) supportCustomPort() bool {

View File

@@ -1,3 +1,32 @@
// Package dns implements the client-side DNS stack: listener/service on the
// peer's tunnel address, handler chain that routes questions by domain and
// priority, and upstream resolvers that forward what remains to configured
// nameservers.
//
// # Upstream resolution and the race model
//
// When two or more nameserver groups target the same domain, DefaultServer
// merges them into one upstream handler whose state is:
//
// upstreamResolverBase
// └── upstreamServers []upstreamRace // one entry per source NS group
// └── []netip.AddrPort // primary, fallback, ...
//
// Each source nameserver group contributes one upstreamRace. Within a race
// upstreams are tried in order: the next is used only on failure (timeout,
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
// the walk. When more than one race exists, ServeDNS fans out one
// goroutine per race and returns the first valid answer, cancelling the
// rest. A handler with a single race skips the fan-out.
//
// # Health projection
//
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
// periodically merges these snapshots across handlers and projects them
// into peer.NSGroupState. There is no active probing: a group is marked
// unhealthy only when every seen upstream has a recent failure and none
// has a recent success. Healthy→unhealthy fires a single
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
package dns
import (
@@ -11,11 +40,8 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -25,7 +51,8 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var currentMTU uint16 = iface.DefaultMTU
@@ -67,15 +94,17 @@ const (
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
// within one race, so a slow primary can't eat the whole race budget.
raceMaxTotalTimeout = 5 * time.Second
// raceMinPerUpstreamTimeout is the floor applied when dividing
// raceMaxTotalTimeout across upstreams within a race.
raceMinPerUpstreamTimeout = 2 * time.Second
)
const (
protoUDP = "udp"
@@ -84,6 +113,69 @@ const (
type dnsProtocolKey struct{}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
// upstreamRace is an ordered list of upstreams derived from one configured
// nameserver group. Order matters: the first upstream is tried first, the
// second only on failure, and so on. Multiple upstreamRace values coexist
// inside one resolver when overlapping nameserver groups target the same
// domain; those races run in parallel and the first valid answer wins.
type upstreamRace []netip.AddrPort
// UpstreamHealth is the last query-path outcome for a single upstream,
// consumed by nameserver-group status projection.
type UpstreamHealth struct {
LastOk time.Time
LastFail time.Time
LastErr string
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []upstreamRace
domain domain.Domain
upstreamTimeout time.Duration
healthMu sync.RWMutex
health map[netip.AddrPort]*UpstreamHealth
statusRecorder *peer.Status
// selectedRoutes returns the current set of client routes the admin
// has enabled. Called lazily from the query hot path when an upstream
// might need a tunnel-bound client (iOS) and from health projection.
selectedRoutes func() route.HAMap
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
type raceResult struct {
msg *dns.Msg
upstream netip.AddrPort
protocol string
ede string
failures []upstreamFailure
}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
@@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
@@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []netip.AddrPort
domain string
disabled bool
successCount atomic.Int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{
ctx: ctx,
cancel: cancel,
domain: domain,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
statusRecorder: statusRecorder,
ctx: ctx,
cancel: cancel,
domain: d,
upstreamTimeout: UpstreamTimeout,
statusRecorder: statusRecorder,
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("Upstream %s", u.upstreamServers)
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
}
// ID returns the unique handler ID
// ID returns the unique handler ID. Race groupings and within-race
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
// the same servers but with different semantics (serial fallback vs
// parallel race), so their handlers must not collide.
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers)
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
hash.Write([]byte(u.domain.PunycodeString() + ":"))
for _, race := range u.upstreamServers {
hash.Write([]byte("["))
for _, s := range race {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
hash.Write([]byte("]"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
@@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
}
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
u.cancel()
}
u.mutex.Lock()
u.wg.Wait()
u.mutex.Unlock()
// flatUpstreams is for logging and ID hashing only, not for dispatch.
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
var out []netip.AddrPort
for _, g := range u.upstreamServers {
out = append(out, g...)
}
return out
}
// setSelectedRoutes swaps the accessor used to classify overlay-routed
// upstreams. Called when route sources are wired after the handler was
// built (permanent / iOS constructors).
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
u.selectedRoutes = selected
}
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
if len(servers) == 0 {
return
}
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
}
// ServeDNS handles a DNS request
@@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
groups := u.upstreamServers
switch len(groups) {
case 0:
return false, nil
case 1:
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
default:
return u.raceAll(ctx, w, r, groups, logger)
}
}
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
res := u.tryRace(ctx, r, group)
if res.msg == nil {
return false, res.failures
}
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, res.failures
}
// raceAll runs one worker per group in parallel, taking the first valid
// answer and cancelling the rest.
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
raceCtx, cancel := context.WithCancel(ctx)
defer cancel()
// Buffer sized to len(groups) so workers never block on send, even
// after the coordinator has returned.
results := make(chan raceResult, len(groups))
for _, g := range groups {
// tryRace clones the request per attempt, so workers never share
// a *dns.Msg and concurrent EDNS0 mutations can't race.
go func(g upstreamRace) {
results <- u.tryRace(raceCtx, r, g)
}(g)
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
for range groups {
select {
case res := <-results:
failures = append(failures, res.failures...)
if res.msg != nil {
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, failures
}
case <-ctx.Done():
return false, failures
}
}
return false, failures
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
timeout := u.upstreamTimeout
if len(group) > 1 {
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
// still honor raceMinPerUpstreamTimeout as a floor for correctness
// on slow links, but the outer context ensures the combined walk
// cannot exceed the cap regardless of group size.
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
defer cancel()
}
var failures []upstreamFailure
for _, upstream := range group {
if ctx.Err() != nil {
return raceResult{failures: failures}
}
// Clone the request per attempt: the exchange path mutates EDNS0
// options in-place, so reusing the same *dns.Msg across sequential
// upstreams would carry those mutations (e.g. a reduced UDP size)
// into the next attempt.
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
if failure != nil {
failures = append(failures, *failure)
continue
}
res.failures = failures
return res
}
return raceResult{failures: failures}
}
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
// Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures.
// Operate on a copy so the inbound request is unchanged: a client that
// did not advertise EDNS0 must not see an OPT in the response.
// The caller already passed a per-attempt copy, so we can mutate r
// directly; hadEdns reflects the original client request's state and
// controls whether we strip the OPT from the response.
hadEdns := r.IsEdns0() != nil
reqUp := r
if !hadEdns {
reqUp = r.Copy()
reqUp.SetEdns0(upstreamUDPSize(), false)
r.SetEdns0(upstreamUDPSize(), false)
}
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
}()
startTime := time.Now()
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
if err != nil {
return u.handleUpstreamError(err, upstream, startTime)
// A parent cancellation (e.g., another race won and the coordinator
// cancelled the losers) is not an upstream failure. Check both the
// error chain and the parent context: a transport may surface the
// cancellation as a read/deadline error rather than context.Canceled.
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
}
failure := u.handleUpstreamError(err, upstream, startTime)
u.markUpstreamFail(upstream, failure.reason)
return raceResult{}, failure
}
if rm == nil || !rm.Response {
return &upstreamFailure{upstream: upstream, reason: "no response"}
u.markUpstreamFail(upstream, "no response")
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
if code, ok := nonRetryableEDE(rm); ok {
resutil.SetMeta(w, "ede", edeName(code))
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}
// healthEntry returns the mutable health record for addr, lazily creating
// the map and the entry. Caller must hold u.healthMu.
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
if u.health == nil {
u.health = make(map[netip.AddrPort]*UpstreamHealth)
}
h := u.health[addr]
if h == nil {
h = &UpstreamHealth{}
u.health[addr] = h
}
return h
}
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastOk = time.Now()
h.LastFail = time.Time{}
h.LastErr = ""
}
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastFail = time.Now()
h.LastErr = reason
}
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
u.healthMu.RLock()
defer u.healthMu.RUnlock()
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
for k, v := range u.health {
out[k] = *v
}
return out
}
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
@@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
if proto != "" {
resutil.SetMeta(w, "upstream_protocol", proto)
}
// Clear Zero bit from external responses to prevent upstream servers from
@@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
totalUpstreams := len(u.upstreamServers)
totalUpstreams := len(u.flatUpstreams())
failedCount := len(failures)
failureSummary := formatFailures(failures)
@@ -434,119 +633,6 @@ func edeName(code uint16) string {
return fmt.Sprintf("EDE %d", code)
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
u.mutex.Lock()
defer u.mutex.Unlock()
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
}
var success bool
var mu sync.Mutex
var wg sync.WaitGroup
var errs *multierror.Error
for _, upstream := range u.upstreamServers {
wg.Add(1)
go func(upstream netip.AddrPort) {
defer wg.Done()
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
if err != nil {
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
mu.Lock()
success = true
mu.Unlock()
}(upstream)
}
wg.Wait()
select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errs.ErrorOrNil())
if u.statusRecorder == nil {
return
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
)
}
}
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
func (u *upstreamResolverBase) waitUntilResponse() {
exponentialBackOff := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.1,
MaxInterval: u.reactivatePeriod,
MaxElapsedTime: 0,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
operation := func() error {
select {
case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default:
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
return nil
}
}
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil {
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
}
// isTimeout returns true if the given error is a network timeout error.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
@@ -558,45 +644,6 @@ func isTimeout(err error) bool {
return false
}
func (u *upstreamResolverBase) disable(err error) {
if u.disabled {
return
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
}
func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
defer cancel()
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
@@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
rm, t, err := client.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
@@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
// so net.Dialer doesn't reject the TCP dial with "mismatched local
// address type".
func toTCPClient(c *dns.Client) *dns.Client {
tcp := *c
tcp.Net = protoTCP
if tcp.Dialer == nil {
return &tcp
}
d := *tcp.Dialer
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
}
tcp.Dialer = &d
return &tcp
}
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
@@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
// haMapRouteCount returns the total number of routes across all HA
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
// routes per key, so len(hm) is the number of HA groups, not routes.
func haMapRouteCount(hm route.HAMap) int {
total := 0
for _, routes := range hm {
total += len(routes)
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
return total
}
// haMapContains checks whether ip is covered by any concrete prefix in
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
// routes carry a placeholder Network that can't be prefix-checked, so we
// can't know at this point whether ip is reached through one. Callers
// decide how to interpret the unknown: health projection treats it as
// "possibly routed" to avoid emitting false-positive warnings during
// startup, while iOS dial selection requires a concrete match before
// binding to the tunnel.
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
for _, routes := range hm {
for _, r := range routes {
if r.IsDynamic() {
haveDynamic = true
continue
}
if r.Network.Contains(ip) {
return true, haveDynamic
}
}
}
return false, haveDynamic
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -26,9 +27,9 @@ func newUpstreamResolver(
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder,

View File

@@ -12,6 +12,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -24,9 +25,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),

View File

@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolverIOS struct {
@@ -27,9 +28,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP = upstreamIP.Unmap()
}
addr := u.wgIface.Address()
var routed bool
if u.selectedRoutes != nil {
// Only a concrete prefix match binds to the tunnel: dialing
// through a private client for an upstream we can't prove is
// routed would break public resolvers.
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
}
needsPrivate := addr.Network.Contains(upstreamIP) ||
addr.IPv6Net.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
routed
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return ExchangeWithFallback(nil, client, r, upstream)
return ExchangeWithFallback(ctx, client, r, upstream)
}
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.addRace(servers)
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
cancel()
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}
// exchange mock implementation of exchange from upstreamResolver
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
return c.r, c.rtt, c.err
}
type mockUpstreamResponse struct {
msg *dns.Msg
err error
msg *dns.Msg
err error
delay time.Duration
}
type mockUpstreamResolverPerServer struct {
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
rtt time.Duration
}
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
r, ok := c.responses[upstream]
if !ok {
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
r: new(dns.Msg),
rtt: time.Millisecond,
}
resolver := &upstreamResolverBase{
ctx: context.TODO(),
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: time.Microsecond * 100,
}
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
failed := false
resolver.deactivate = func(error) {
failed = true
// After deactivation, make the mock client work again
mockClient.err = nil
}
reactivated := false
resolver.reactivate = func() {
reactivated = true
}
resolver.ProbeAvailability(context.TODO())
if !failed {
t.Errorf("expected that resolving was deactivated")
return
}
if !resolver.disabled {
t.Errorf("resolver should be Disabled")
return
}
time.Sleep(time.Millisecond * 200)
if !reactivated {
t.Errorf("expected that resolving was reactivated")
return
}
if resolver.disabled {
t.Errorf("should be enabled")
if r.delay > 0 {
select {
case <-time.After(r.delay):
case <-ctx.Done():
return nil, c.rtt, ctx.Err()
}
}
return r.msg, c.rtt, r.err
}
func TestUpstreamResolver_Failover(t *testing.T) {
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: trackingClient,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamServers: []netip.AddrPort{upstream},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
}
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
// configured for the same domain, with one broken group. The merge+race
// path should answer as fast as the working group and not pay the timeout
// of the broken one on every query.
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
broken := netip.MustParseAddrPort("192.0.2.1:53")
working := netip.MustParseAddrPort("192.0.2.2:53")
successAnswer := "192.0.2.100"
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
// Force the broken upstream to only unblock via timeout /
// cancellation so the assertion below can't pass if races
// were run serially.
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: 250 * time.Millisecond,
}
resolver.addRace([]netip.AddrPort{broken})
resolver.addRace([]netip.AddrPort{working})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
start := time.Now()
resolver.ServeDNS(responseWriter, inputMSG)
elapsed := time.Since(start)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
require.NotEmpty(t, responseMSG.Answer)
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
// Working group answers in a single RTT; the broken group's
// timeout (100ms) must not block the response.
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
}
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
// resolver returns SERVFAIL rather than leaking a partial response.
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a})
resolver.addRace([]netip.AddrPort{b})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
require.NotNil(t, responseMSG)
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
}
// TestUpstreamResolver_HealthTracking verifies that query-path results are
// recorded into per-upstream health, which is what projects back to
// NSGroupState for status reporting.
func TestUpstreamResolver_HealthTracking(t *testing.T) {
ok := netip.MustParseAddrPort("192.0.2.10:53")
bad := netip.MustParseAddrPort("192.0.2.11:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{ok, bad})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, ok)
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
assert.Empty(t, health[ok].LastErr)
// bad upstream was never tried because ok answered first; its health
// should remain unset.
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
var receivedUDPSize atomic.Uint32
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
receivedUDPSize.Store(uint32(opt.UDPSize()))
}
m := new(dns.Msg)
m.SetReply(r)
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
"upstream should see capped EDNS0, not the client's 4096")
}
@@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: tracking,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamServers: []upstreamRace{{upstream1, upstream2}},
upstreamTimeout: UpstreamTimeout,
}

View File

@@ -512,16 +512,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
@@ -1386,9 +1377,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
go e.dnsServer.ProbeAvailability()
return nil
}
@@ -1932,7 +1920,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:

View File

@@ -217,14 +217,6 @@ type Status struct {
eventStreams map[string]chan *proto.SystemEvent
eventQueue *EventQueue
// stateChangeStreams fan-out connection-state changes (connected /
// disconnected / connecting / address change / peers list change) to
// every active SubscribeStatus gRPC stream. Each subscriber gets a
// buffered chan; the notifier non-blockingly pings them so a slow
// consumer can never stall the daemon.
stateChangeMux sync.Mutex
stateChangeStreams map[string]chan struct{}
ingressGwMgr *ingressgw.Manager
routeIDLookup routeIDLookup
@@ -238,7 +230,6 @@ func NewRecorder(mgmAddress string) *Status {
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
eventStreams: make(map[string]chan *proto.SystemEvent),
eventQueue: NewEventQueue(eventQueueSize),
stateChangeStreams: make(map[string]chan struct{}),
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
@@ -369,7 +360,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
d.notifyStateChange()
return nil
}
@@ -395,7 +385,6 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
// todo: consider to make sense of this notification or not
d.notifier.peerListChanged(numPeers)
d.notifyStateChange()
return nil
}
@@ -421,7 +410,6 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
// todo: consider to make sense of this notification or not
d.notifier.peerListChanged(numPeers)
d.notifyStateChange()
return nil
}
@@ -471,7 +459,6 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
d.notifyStateChange()
return nil
}
@@ -508,7 +495,6 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
d.notifyStateChange()
return nil
}
@@ -544,7 +530,6 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
d.notifyStateChange()
return nil
}
@@ -583,7 +568,6 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
d.notifyStateChange()
return nil
}
@@ -677,7 +661,6 @@ func (d *Status) FinishPeerListModifications() {
for _, rd := range dispatches {
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
}
d.notifyStateChange()
}
func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
@@ -736,7 +719,6 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip)
d.notifyStateChange()
}
// AddLocalPeerStateRoute adds a route to the local peer state
@@ -805,7 +787,6 @@ func (d *Status) CleanLocalPeerState() {
d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip)
d.notifyStateChange()
}
// MarkManagementDisconnected sets ManagementState to disconnected
@@ -818,7 +799,6 @@ func (d *Status) MarkManagementDisconnected(err error) {
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
d.notifyStateChange()
}
// MarkManagementConnected sets ManagementState to connected
@@ -831,7 +811,6 @@ func (d *Status) MarkManagementConnected() {
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
d.notifyStateChange()
}
// UpdateSignalAddress update the address of the signal server
@@ -872,7 +851,6 @@ func (d *Status) MarkSignalDisconnected(err error) {
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
d.notifyStateChange()
}
// MarkSignalConnected sets SignalState to connected
@@ -885,7 +863,6 @@ func (d *Status) MarkSignalConnected() {
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
d.notifyStateChange()
}
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
@@ -1083,19 +1060,16 @@ func (d *Status) GetFullStatus() FullStatus {
// ClientStart will notify all listeners about the new service state
func (d *Status) ClientStart() {
d.notifier.clientStart()
d.notifyStateChange()
}
// ClientStop will notify all listeners about the new service state
func (d *Status) ClientStop() {
d.notifier.clientStop()
d.notifyStateChange()
}
// ClientTeardown will notify all listeners about the service is under teardown
func (d *Status) ClientTeardown() {
d.notifier.clientTearDown()
d.notifyStateChange()
}
// SetConnectionListener set a listener to the notifier
@@ -1237,62 +1211,6 @@ func (d *Status) GetEventHistory() []*proto.SystemEvent {
return d.eventQueue.GetAll()
}
// SubscribeToStateChanges hands back a channel that receives a tick on
// every connection-state change (connected / disconnected / connecting /
// address change / peers-list change). The channel is buffered to one
// pending tick so a coalesced burst still wakes the consumer exactly
// once. Pass the returned id to UnsubscribeFromStateChanges to detach.
func (d *Status) SubscribeToStateChanges() (string, <-chan struct{}) {
d.stateChangeMux.Lock()
defer d.stateChangeMux.Unlock()
id := uuid.New().String()
ch := make(chan struct{}, 1)
d.stateChangeStreams[id] = ch
return id, ch
}
// UnsubscribeFromStateChanges releases a SubscribeToStateChanges channel
// and closes it so any consumer goroutine selecting on the channel
// unblocks cleanly.
func (d *Status) UnsubscribeFromStateChanges(id string) {
d.stateChangeMux.Lock()
defer d.stateChangeMux.Unlock()
if ch, ok := d.stateChangeStreams[id]; ok {
close(ch)
delete(d.stateChangeStreams, id)
}
}
// notifyStateChange wakes every SubscribeToStateChanges subscriber. Drops
// the tick if a subscriber's buffer is full — by definition the consumer
// is already going to fetch the latest snapshot, so multiple pending ticks
// would be redundant.
func (d *Status) notifyStateChange() {
d.stateChangeMux.Lock()
defer d.stateChangeMux.Unlock()
for _, ch := range d.stateChangeStreams {
select {
case ch <- struct{}{}:
default:
}
}
}
// NotifyStateChange is the public wake-the-subscribers entry point used by
// callers that mutate state outside the peer recorder — most importantly
// the connect-state machine, which writes StatusNeedsLogin into the
// shared contextState (client/internal/state.go) without touching any
// recorder field. Without this push the SubscribeStatus stream stays on
// the previous snapshot until an unrelated peer/management/signal
// change happens to fire notifyStateChange, leaving the UI's status
// out of sync with the daemon.
func (d *Status) NotifyStateChange() {
d.notifyStateChange()
}
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
d.mux.Lock()
defer d.mux.Unlock()

View File

@@ -53,6 +53,7 @@ type Manager interface {
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetActiveClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetActiveClientRoutes returns the subset of selected client routes
// that are currently reachable: the route's peer is Connected and is
// the one actively carrying the route (not just an HA sibling).
func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
m.mux.Lock()
selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
recorder := m.statusRecorder
m.mux.Unlock()
if recorder == nil {
return selected
}
out := make(route.HAMap, len(selected))
for id, routes := range selected {
for _, r := range routes {
st, err := recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if st.ConnStatus != peer.StatusConnected {
continue
}
if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
continue
}
out[id] = routes
break
}
}
return out
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()
@@ -704,7 +738,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
}
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
if len(routes) == 0 {
return false
}
return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network)
}
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {

View File

@@ -19,6 +19,7 @@ type MockManager struct {
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetActiveClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
return nil
}
// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
func (m *MockManager) GetActiveClientRoutes() route.HAMap {
if m.GetActiveClientRoutesFunc != nil {
return m.GetActiveClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -12,10 +13,6 @@ import (
"github.com/netbirdio/netbird/route"
)
const (
exitNodeCIDR = "0.0.0.0/0"
)
type RouteSelector struct {
mu sync.RWMutex
deselectedRoutes map[route.NetID]struct{}
@@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
isSelected := !deselected
return isSelected
return rs.isSelectedLocked(routeID)
}
// FilterSelected removes unselected routes from the provided map.
@@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
filtered := route.HAMap{}
for id, rt := range routes {
netID := id.NetID()
_, deselected := rs.deselectedRoutes[netID]
if !deselected {
if !rs.isDeselectedLocked(id.NetID()) {
filtered[id] = rt
}
}
return filtered
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
filtered := make(route.HAMap, len(routes))
for id, rt := range routes {
netID := id.NetID()
if rs.isDeselected(netID) {
if rs.isDeselectedLocked(netID) {
continue
}
@@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected || rs.deselectAll
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
@@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter(
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelections() {
// user made explicit selects/deselects
if rs.IsSelected(netID) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func (rs *RouteSelector) hasUserSelections() bool {
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {

View File

@@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
})
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
})
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-default-route entry named "corp-v6" is not an exit node and
// must not be skipped because its v4 base "corp" is deselected.
corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")}
routes := route.HAMap{
"corp-v6|10.0.0.0/8": {corpV6},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"),
"non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes")
})
}
// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's
// SkipAutoApply flag governs each untouched route independently, even when
// the user has explicit selections on other routes.
func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) {
autoApplied := &route.Route{
NetID: "Auto",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: false,
}
skipApply := &route.Route{
NetID: "Skip",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"Auto|0.0.0.0/0": {autoApplied},
"Skip|0.0.0.0/0": {skipApply},
}
rs := routeselector.NewRouteSelector()
// User makes an unrelated explicit selection elsewhere.
require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"}))
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included")
assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection")
}
// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized
// as exit nodes by the selector's filter path.
func TestRouteSelector_V6ExitIsExitNode(t *testing.T) {
v6Exit := &route.Route{
NetID: "V6Only",
Network: netip.MustParsePrefix("::/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"V6Only|::/0": {v6Exit},
}
rs := routeselector.NewRouteSelector()
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply")
}
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialRoutes := []route.NetID{"route1", "route2", "route3"}
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}

View File

@@ -33,34 +33,17 @@ func CtxGetState(ctx context.Context) *contextState {
}
type contextState struct {
err error
status StatusType
mutex sync.Mutex
onChange func()
}
// SetOnChange installs a callback fired after every successful Set. Used by
// the daemon to wire the status recorder's notifyStateChange so any
// state.Set in the connect/login paths pushes a fresh snapshot to
// SubscribeStatus subscribers without each callsite having to opt in.
// The callback runs outside the contextState mutex to avoid a lock-order
// dependency with the recorder's stateChangeMux.
func (c *contextState) SetOnChange(fn func()) {
c.mutex.Lock()
c.onChange = fn
c.mutex.Unlock()
err error
status StatusType
mutex sync.Mutex
}
func (c *contextState) Set(update StatusType) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.status = update
c.err = nil
cb := c.onChange
c.mutex.Unlock()
if cb != nil {
cb()
}
}
func (c *contextState) Status() (StatusType, error) {
@@ -74,17 +57,6 @@ func (c *contextState) Status() (StatusType, error) {
return c.status, nil
}
// CurrentStatus returns the last status set via Set, ignoring any wrapped
// error. Use when the status is needed for reporting purposes (e.g. the
// status snapshot stream) and a transient wrapped error from a retry loop
// shouldn't blank out the underlying status.
func (c *contextState) CurrentStatus() StatusType {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.status
}
func (c *contextState) Wrap(err error) error {
c.mutex.Lock()
defer c.mutex.Unlock()

View File

@@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
// Stop the internal client and free the resources

View File

@@ -32,6 +32,9 @@
</File>
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
<?if $(var.ArchSuffix) = "amd64" ?>
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
<?endif ?>
<ServiceInstall
Id="NetBirdService"
@@ -59,14 +62,6 @@
<Component Id="NetbirdAumidRegistry" Guid="*">
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
<!-- Pre-seed the CLSID the Wails notifications service reads on
first startup (notifications_windows.go:getGUID looks for
the CustomActivator value under this key). Without this
the service generates a fresh per-install UUID, which
diverges from the ToastActivatorCLSID set on the Start
Menu / Desktop shortcuts above and the COM activator
never fires when a toast is clicked. -->
<RegistryValue Name="CustomActivator" Type="string" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
</RegistryKey>
</Component>
</StandardDirectory>
@@ -90,40 +85,10 @@
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" TerminateProcess="0" />
<!-- WebView2 evergreen runtime detection.
Probe both the per-machine and per-user EdgeUpdate keys; if either
reports a non-empty `pv` value the runtime is already installed
and we skip the bootstrapper. -->
<Property Id="WEBVIEW2_VERSION_HKLM">
<RegistrySearch Id="WV2HKLM" Root="HKLM"
Key="SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
Name="pv" Type="raw" Bitness="always64" />
</Property>
<Property Id="WEBVIEW2_VERSION_HKCU">
<RegistrySearch Id="WV2HKCU" Root="HKCU"
Key="Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
Name="pv" Type="raw" />
</Property>
<!-- Embed the bootstrapper payload. Path is relative to the WiX
working directory; sign-pipelines stages it next to client/
via `wails3 generate webview2bootstrapper`. -->
<Binary Id="WebView2Bootstrapper" SourceFile=".\client\MicrosoftEdgeWebview2Setup.exe" />
<CustomAction Id="InstallWebView2"
BinaryRef="WebView2Bootstrapper"
ExeCommand="/silent /install"
Execute="deferred"
Impersonate="no"
Return="check" />
<InstallExecuteSequence>
<Custom Action="InstallWebView2" Before="InstallFinalize"
Condition="NOT WEBVIEW2_VERSION_HKLM AND NOT WEBVIEW2_VERSION_HKCU AND NOT REMOVE" />
</InstallExecuteSequence>
<!-- Icons -->
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\build\windows\icon.ico" />
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />
<Property Id="ARPPRODUCTICON" Value="NetbirdIcon" />
</Package>

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v3.21.12
// protoc v6.33.1
// source: daemon.proto
package proto
@@ -823,15 +823,9 @@ func (x *WaitSSOLoginResponse) GetEmail() string {
}
type UpRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
// async instructs the daemon to start the connection attempt and return
// immediately without waiting for the engine to become ready. Status updates
// are delivered via the SubscribeStatus stream. When false (the default) the
// RPC blocks until the engine is running or gives up, which is the behaviour
// needed by the CLI.
Async bool `protobuf:"varint,4,opt,name=async,proto3" json:"async,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -880,13 +874,6 @@ func (x *UpRequest) GetUsername() string {
return ""
}
func (x *UpRequest) GetAsync() bool {
if x != nil {
return x.Async
}
return false
}
type UpResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -6322,11 +6309,10 @@ const file_daemon_proto_rawDesc = "" +
"\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" +
"\bhostname\x18\x02 \x01(\tR\bhostname\",\n" +
"\x14WaitSSOLoginResponse\x12\x14\n" +
"\x05email\x18\x01 \x01(\tR\x05email\"\x8c\x01\n" +
"\x05email\x18\x01 \x01(\tR\x05email\"v\n" +
"\tUpRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12\x14\n" +
"\x05async\x18\x04 \x01(\bR\x05asyncB\x0e\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_usernameJ\x04\b\x03\x10\x04\"\f\n" +
"\n" +
@@ -6787,13 +6773,12 @@ const file_daemon_proto_rawDesc = "" +
"\n" +
"EXPOSE_UDP\x10\x03\x12\x0e\n" +
"\n" +
"EXPOSE_TLS\x10\x042\xf5\x17\n" +
"EXPOSE_TLS\x10\x042\xaf\x17\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
"\x02Up\x12\x11.daemon.UpRequest\x1a\x12.daemon.UpResponse\"\x00\x129\n" +
"\x06Status\x12\x15.daemon.StatusRequest\x1a\x16.daemon.StatusResponse\"\x00\x12D\n" +
"\x0fSubscribeStatus\x12\x15.daemon.StatusRequest\x1a\x16.daemon.StatusResponse\"\x000\x01\x123\n" +
"\x06Status\x12\x15.daemon.StatusRequest\x1a\x16.daemon.StatusResponse\"\x00\x123\n" +
"\x04Down\x12\x13.daemon.DownRequest\x1a\x14.daemon.DownResponse\"\x00\x12B\n" +
"\tGetConfig\x12\x18.daemon.GetConfigRequest\x1a\x19.daemon.GetConfigResponse\"\x00\x12K\n" +
"\fListNetworks\x12\x1b.daemon.ListNetworksRequest\x1a\x1c.daemon.ListNetworksResponse\"\x00\x12Q\n" +
@@ -6994,84 +6979,82 @@ var file_daemon_proto_depIdxs = []int32{
7, // 38: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
9, // 39: daemon.DaemonService.Up:input_type -> daemon.UpRequest
11, // 40: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
11, // 41: daemon.DaemonService.SubscribeStatus:input_type -> daemon.StatusRequest
13, // 42: daemon.DaemonService.Down:input_type -> daemon.DownRequest
15, // 43: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
26, // 44: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
28, // 45: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
28, // 46: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
4, // 47: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
35, // 48: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
37, // 49: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
39, // 50: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
42, // 51: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
44, // 52: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
46, // 53: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
48, // 54: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
51, // 55: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
92, // 56: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest
94, // 57: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest
96, // 58: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest
54, // 59: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
56, // 60: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
58, // 61: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
60, // 62: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
62, // 63: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
64, // 64: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
66, // 65: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
69, // 66: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
71, // 67: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
73, // 68: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
75, // 69: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
77, // 70: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
79, // 71: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 72: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
83, // 73: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
85, // 74: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
87, // 75: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
89, // 76: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
6, // 77: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
8, // 78: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
10, // 79: daemon.DaemonService.Up:output_type -> daemon.UpResponse
12, // 80: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
12, // 81: daemon.DaemonService.SubscribeStatus:output_type -> daemon.StatusResponse
14, // 82: daemon.DaemonService.Down:output_type -> daemon.DownResponse
16, // 83: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
27, // 84: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
29, // 85: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
29, // 86: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
34, // 87: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
36, // 88: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
38, // 89: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
40, // 90: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
43, // 91: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
45, // 92: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
47, // 93: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
49, // 94: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
53, // 95: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
93, // 96: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket
95, // 97: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse
97, // 98: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse
55, // 99: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
57, // 100: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
59, // 101: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
61, // 102: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
63, // 103: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
65, // 104: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
67, // 105: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
70, // 106: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
72, // 107: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
74, // 108: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
76, // 109: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
78, // 110: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 111: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 112: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
84, // 113: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
86, // 114: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
88, // 115: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
90, // 116: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
77, // [77:117] is the sub-list for method output_type
37, // [37:77] is the sub-list for method input_type
13, // 41: daemon.DaemonService.Down:input_type -> daemon.DownRequest
15, // 42: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
26, // 43: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
28, // 44: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
28, // 45: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
4, // 46: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
35, // 47: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
37, // 48: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
39, // 49: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
42, // 50: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
44, // 51: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
46, // 52: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
48, // 53: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
51, // 54: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
92, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest
94, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest
96, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest
54, // 58: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
56, // 59: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
58, // 60: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
60, // 61: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
62, // 62: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
64, // 63: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
66, // 64: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
69, // 65: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
71, // 66: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
73, // 67: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
75, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
77, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
79, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
83, // 72: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
85, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
87, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
89, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
6, // 76: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
8, // 77: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
10, // 78: daemon.DaemonService.Up:output_type -> daemon.UpResponse
12, // 79: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
14, // 80: daemon.DaemonService.Down:output_type -> daemon.DownResponse
16, // 81: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
27, // 82: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
29, // 83: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
29, // 84: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
34, // 85: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
36, // 86: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
38, // 87: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
40, // 88: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
43, // 89: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
45, // 90: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
47, // 91: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
49, // 92: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
53, // 93: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
93, // 94: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket
95, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse
97, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse
55, // 97: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
57, // 98: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
59, // 99: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
61, // 100: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
63, // 101: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
65, // 102: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
67, // 103: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
70, // 104: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
72, // 105: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
74, // 106: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
76, // 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
78, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
84, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
86, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
88, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
90, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
76, // [76:115] is the sub-list for method output_type
37, // [37:76] is the sub-list for method input_type
37, // [37:37] is the sub-list for extension type_name
37, // [37:37] is the sub-list for extension extendee
0, // [0:37] is the sub-list for field type_name

View File

@@ -24,12 +24,6 @@ service DaemonService {
// Status of the service.
rpc Status(StatusRequest) returns (StatusResponse) {}
// SubscribeStatus pushes a fresh StatusResponse on connection state
// changes (Connected / Disconnected / Connecting / address change /
// peers list change). The first message on the stream is the current
// snapshot, so a freshly-subscribed UI doesn't need to also call Status.
rpc SubscribeStatus(StatusRequest) returns (stream StatusResponse) {}
// Down stops engine work in the daemon.
rpc Down(DownRequest) returns (DownResponse) {}
@@ -233,12 +227,6 @@ message UpRequest {
optional string profileName = 1;
optional string username = 2;
reserved 3;
// async instructs the daemon to start the connection attempt and return
// immediately without waiting for the engine to become ready. Status updates
// are delivered via the SubscribeStatus stream. When false (the default) the
// RPC blocks until the engine is running or gives up, which is the behaviour
// needed by the CLI.
bool async = 4;
}
message UpResponse {}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,93 @@
package server
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
func TestPersistLoginOverrides(t *testing.T) {
strPtr := func(s string) *string { return &s }
tests := []struct {
name string
initialMgmtURL string
initialPSK string
newMgmtURL string
newPSK *string
wantMgmtURL string
wantPSK string
}{
{
name: "persist new management URL",
initialMgmtURL: "https://old.example.com:33073",
newMgmtURL: "https://new.example.com:33073",
wantMgmtURL: "https://new.example.com:33073",
},
{
name: "persist new pre-shared key",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "old-key",
newPSK: strPtr("new-key"),
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "new-key",
},
{
name: "persist both",
initialMgmtURL: "https://old.example.com:33073",
initialPSK: "old-key",
newMgmtURL: "https://new.example.com:33073",
newPSK: strPtr("new-key"),
wantMgmtURL: "https://new.example.com:33073",
wantPSK: "new-key",
},
{
name: "no inputs preserves existing",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "existing-key",
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "existing-key",
},
{
name: "empty PSK pointer is ignored",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "existing-key",
newPSK: strPtr(""),
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "existing-key",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
origDefault := profilemanager.DefaultConfigPath
t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault })
dir := t.TempDir()
profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json")
seed := profilemanager.ConfigInput{
ConfigPath: profilemanager.DefaultConfigPath,
ManagementURL: tt.initialMgmtURL,
}
if tt.initialPSK != "" {
seed.PreSharedKey = strPtr(tt.initialPSK)
}
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")
cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath)
require.NoError(t, err, "read back config")
require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL")
require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key")
})
}
}

View File

@@ -140,15 +140,6 @@ func (s *Server) Start() error {
}
state := internal.CtxGetState(s.rootCtx)
// Every contextState.Set in the connect/login/server paths must push a
// SubscribeStatus snapshot, otherwise transitions that don't happen to
// be accompanied by a Mark{Management,Signal,...} call (e.g. plain
// StatusNeedsLogin after a PermissionDenied login, StatusLoginFailed
// after OAuth init failure, StatusIdle in the Login defer) leave the
// UI stuck on the previous status until the next unrelated peer event.
// Binding the recorder here means new state.Set callsites don't have
// to opt in individually.
state.SetOnChange(s.statusRecorder.NotifyStateChange)
if err := handlePanicLog(); err != nil {
log.Warnf("failed to redirect stderr: %v", err)
@@ -229,20 +220,10 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
// close(giveUpChan) MUST run on every exit path (DisableAutoConnect
// return, backoff.Retry return, panic) — Down() blocks for up to 5s
// waiting on this signal before flipping the state to Idle, and a
// missed close leaves Down() always hitting the timeout. The signal
// fires AFTER clientRunning=false is committed under the mutex so a
// Down/Up racing with the goroutine exit never observes a half-state
// (chan closed but clientRunning still true).
defer func() {
s.mutex.Lock()
s.clientRunning = false
s.mutex.Unlock()
if giveUpChan != nil {
close(giveUpChan)
}
}()
if s.config.DisableAutoConnect {
@@ -277,15 +258,6 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
runOperation := func() error {
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
if err != nil {
// PermissionDenied means the daemon transitioned to NeedsLogin
// inside connect(). Without backoff.Permanent the outer retry
// re-enters connect(), which resets the state to Connecting and
// makes the tray flicker between NeedsLogin and Connecting until
// the user logs in. Stop retrying and let the state stick.
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
log.Debugf("run client connection exited with PermissionDenied, waiting for login")
return backoff.Permanent(err)
}
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
}
@@ -297,6 +269,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
if err := backoff.Retry(runOperation, backOff); err != nil {
log.Errorf("operation failed: %v", err)
}
if giveUpChan != nil {
close(giveUpChan)
}
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
@@ -365,7 +341,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
}
if msg.OptionalPreSharedKey != nil {
config.PreSharedKey = msg.OptionalPreSharedKey
if *msg.OptionalPreSharedKey != "" {
config.PreSharedKey = msg.OptionalPreSharedKey
}
}
if msg.CleanDNSLabels {
@@ -512,6 +490,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.mutex.Unlock()
if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil {
log.Errorf("failed to persist login overrides: %v", err)
return nil, fmt.Errorf("persist login overrides: %w", err)
}
config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
@@ -586,35 +569,8 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return &proto.LoginResponse{}, nil
}
// WaitSSOLogin validates the supplied userCode against the in-flight OAuth
// device/PKCE flow and blocks until the user finishes the browser leg.
//
// State transitions on exit:
//
// ┌──────────────────────────────────────────┬──────────────────────────────────┐
// │ Outcome │ contextState │
// ├──────────────────────────────────────────┼──────────────────────────────────┤
// │ Success → loginAttempt → Connected │ StatusConnected (loginAttempt) │
// │ Success → loginAttempt → still-NeedsLogin│ StatusNeedsLogin (loginAttempt) │
// │ Success → loginAttempt error │ StatusLoginFailed (loginAttempt) │
// │ UserCode mismatch │ StatusLoginFailed │
// │ WaitToken: context.Canceled (external │ defer runs: status untouched if │
// │ abort — profile switch invokes │ already NeedsLogin/LoginFailed,│
// │ actCancel/waitCancel, app quit, │ else StatusIdle. Keeps the │
// │ another WaitSSOLogin started) │ cancel from leaking as a │
// │ │ spurious LoginFailed on the │
// │ │ next profile's Up. │
// │ WaitToken: context.DeadlineExceeded │ StatusNeedsLogin │
// │ (OAuth device-code window expired │ (retryable; the UI's "Connect" │
// │ while waiting on the browser leg) │ re-enters the Login flow) │
// │ WaitToken: any other error │ StatusLoginFailed │
// │ (access_denied, expired_token, HTTP │ (genuine auth/IO failure; │
// │ failure, token validation rejection) │ surfaced verbatim to caller) │
// └──────────────────────────────────────────┴──────────────────────────────────┘
//
// The defer at the top of the function applies the Idle fallback so callers
// that bypass the explicit Set calls (the Canceled branch above, the success
// path before loginAttempt) still land on a sensible terminal status.
// WaitSSOLogin uses the userCode to validate the TokenInfo and
// waits for the user to continue with the login on a browser
func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLoginRequest) (*proto.WaitSSOLoginResponse, error) {
s.mutex.Lock()
if s.actCancel != nil {
@@ -674,21 +630,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.mutex.Lock()
s.oauthAuthFlow.expiresAt = time.Now()
s.mutex.Unlock()
switch {
case errors.Is(err, context.Canceled):
// External abort (profile switch, app quit, another
// WaitSSOLogin started). Not a login failure — let the
// top-level defer fall through to StatusIdle so the next
// flow starts from a clean state.
case errors.Is(err, context.DeadlineExceeded):
// OAuth device-code window expired with no user action.
// Retryable — leave the daemon in NeedsLogin so the UI
// keeps the Login affordance instead of reading as a
// hard failure.
state.Set(internal.StatusNeedsLogin)
default:
state.Set(internal.StatusLoginFailed)
}
state.Set(internal.StatusLoginFailed)
log.Errorf("waiting for browser login failed: %v", err)
return nil, err
}
@@ -803,9 +745,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
s.mutex.Unlock()
if msg.GetAsync() {
return &proto.UpResponse{}, nil
}
return s.waitForUp(callerCtx)
}
@@ -905,37 +844,23 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
return nil, err
}
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.mutex.Unlock()
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
// The giveUpChan is closed by the goroutine's deferred cleanup (see
// connectWithRetryRuns) on every exit path. A timeout here typically
// means the goroutine is still wedged inside a slow teardown step.
// The giveUpChan is closed at the end of connectWithRetryRuns.
if giveUpChan != nil {
select {
case <-giveUpChan:
log.Debugf("client goroutine finished, giveUpChan closed")
log.Debugf("client goroutine finished successfully")
case <-time.After(5 * time.Second):
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
}
}
// Set Idle only after the retry goroutine has exited (or timed out).
// Setting it earlier races with the goroutine's own Set(StatusConnecting)
// at the top of each retry attempt, which would leave the snapshot
// stuck at Connecting long after the user asked to disconnect.
internal.CtxGetState(s.rootCtx).Set(internal.StatusIdle)
// Clear stale management/signal errors so the next Up() (typically for a
// different profile) starts with a clean status snapshot. Without this,
// a managementError left over from a LoginFailed cycle persists in the
// statusRecorder and appears in the new profile's initial
// SubscribeStatus snapshot, making the new profile look like it also
// failed to log in.
s.statusRecorder.MarkManagementDisconnected(nil)
s.statusRecorder.MarkSignalDisconnected(nil)
return &proto.DownResponse{}, nil
}
@@ -1044,7 +969,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
return &proto.LogoutResponse{}, nil
}
// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) {
cfgPath, err := activeProf.FilePath()
if err != nil {
@@ -1189,23 +1114,9 @@ func (s *Server) Status(
}
}
return s.buildStatusResponse(msg)
}
// buildStatusResponse composes a StatusResponse from the current daemon
// state. Shared between the unary Status RPC and the SubscribeStatus
// stream so both paths return identical snapshots.
func (s *Server) buildStatusResponse(msg *proto.StatusRequest) (*proto.StatusResponse, error) {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
status, err := internal.CtxGetState(s.rootCtx).Status()
if err != nil {
// state.Status() blanks the status when err is set (e.g. management
// retry loop wrapped a connection error). The underlying status is
// still meaningful and the failure is already surfaced via
// FullStatus.ManagementState.Error, so don't propagate err — that
// would tear down the SubscribeStatus stream and cause the UI to
// mark the daemon as unreachable on every retry.
status = state.CurrentStatus()
return nil, err
}
if status == internal.StatusNeedsLogin && s.isSessionActive.Load() {
@@ -1860,3 +1771,29 @@ func sendTerminalNotification() error {
return wallCmd.Wait()
}
// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the
// active profile config so that subsequent reads pick them up. Empty/nil values are ignored.
func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error {
if preSharedKey != nil && *preSharedKey == "" {
preSharedKey = nil
}
if managementURL == "" && preSharedKey == nil {
return nil
}
cfgPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("active profile file path: %w", err)
}
input := profilemanager.ConfigInput{
ConfigPath: cfgPath,
ManagementURL: managementURL,
PreSharedKey: preSharedKey,
}
if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil {
return fmt.Errorf("update config: %w", err)
}
return nil
}

View File

@@ -1,57 +0,0 @@
package server
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
// SubscribeStatus pushes a fresh StatusResponse on every connection state
// change. The first message is the current snapshot, so a re-subscribing
// client doesn't need to also call Status. Subsequent messages fire when
// the peer recorder reports any of: connected/disconnected/connecting,
// management or signal flip, address change, or peers list change.
//
// The change channel coalesces bursts to a single tick. If the consumer
// is slow the daemon drops extras (not blocks), and the next snapshot
// the consumer pulls already reflects everything.
func (s *Server) SubscribeStatus(req *proto.StatusRequest, stream proto.DaemonService_SubscribeStatusServer) error {
subID, ch := s.statusRecorder.SubscribeToStateChanges()
defer func() {
s.statusRecorder.UnsubscribeFromStateChanges(subID)
log.Debug("client unsubscribed from status updates")
}()
log.Debug("client subscribed to status updates")
if err := s.sendStatusSnapshot(req, stream); err != nil {
return err
}
for {
select {
case _, ok := <-ch:
if !ok {
return nil
}
if err := s.sendStatusSnapshot(req, stream); err != nil {
return err
}
case <-stream.Context().Done():
return nil
}
}
}
func (s *Server) sendStatusSnapshot(req *proto.StatusRequest, stream proto.DaemonService_SubscribeStatusServer) error {
resp, err := s.buildStatusResponse(req)
if err != nil {
log.Warnf("build status snapshot for stream: %v", err)
return err
}
if err := stream.Send(resp); err != nil {
log.Warnf("send status snapshot to stream: %v", err)
return err
}
return nil
}

View File

@@ -252,6 +252,10 @@ func (m *Manager) writeSSHConfig(sshConfig string) error {
return fmt.Errorf("write SSH config file %s: %w", tmpPath, err)
}
if err := os.Chmod(tmpPath, 0644); err != nil {
return fmt.Errorf("chmod SSH config file %s: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, sshConfigPath); err != nil {
return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err)
}

View File

@@ -1,8 +0,0 @@
.task
bin
frontend/dist
frontend/node_modules
frontend/bindings
frontend/.vite
build/linux/appimage/build
build/windows/nsis/MicrosoftEdgeWebview2Setup.exe

View File

@@ -1,265 +0,0 @@
# NetBird Wails UI — Working Notes
This is the Wails v3 desktop UI for NetBird. Go services live in `services/`; the React/TS frontend lives in `frontend/`; bindings between them are generated under `frontend/bindings/`.
## Layout
### Go (top-level package `main`)
- `main.go` — app entry. Builds the gRPC `Conn`, constructs services, registers them with Wails, creates the main webview window, starts the in-process Linux SNI watcher, then the tray, then `peers.Watch`, then `app.Run`. Also wires `--daemon-addr`, `--log-file` (repeatable, defaults to `console`), `--log-level` flags.
- `tray.go``Tray` struct and its menu. Subscribes to `EventStatus`, `EventSystem`, `EventUpdateAvailable`, `EventUpdateProgress`. Owns the per-status icon/dot, the Profiles submenu, the Connect/Disconnect swap, the About → Update flow, session-expired toast.
- `tray_linux.go``init()` that sets `WEBKIT_DISABLE_DMABUF_RENDERER=1` to avoid the blank-white window on VMs / minimal WMs.
- `tray_watcher_linux.go`, `xembed_host_linux.go`, `xembed_tray_linux.{c,h}` — in-process `org.kde.StatusNotifierWatcher` and XEmbed bridge so the tray works on minimal WMs (Fluxbox, OpenBox, i3, dwm, vanilla GNOME without AppIndicator). See "Linux tray support" below.
- `tray_watcher_other.go` — no-op stub on non-Linux builds.
- `signal_unix.go` / `signal_windows.go``listenForShowSignal`. On Unix, SIGUSR1 brings the window forward. On Windows, a named event `Global\NetBirdQuickActionsTriggerEvent` does the same. Mirrors the legacy Fyne UI's external-trigger contract so the installer / CLI keep working.
- `grpc.go``Conn` is a lazy, mutex-protected gRPC channel to the daemon. One `Conn` is shared by every service. `DaemonAddr()` returns `unix:///var/run/netbird.sock` on Linux/macOS and `tcp://127.0.0.1:41731` on Windows.
- `icons.go``//go:embed` the tray/window PNGs from `assets/`. macOS uses template variants (`*-macos.png`); Linux ships light + dark PNGs; Windows reuses the light PNG (multi-frame `.ico` never redrew on Wails3's `NIM_MODIFY`).
- `desktop/desktop.go` — tiny helper returning `GetUIUserAgent()` (`netbird-desktop-ui/<version>`) for the gRPC dialer.
### Wails services (`services/*.go`)
Each service is registered via `app.RegisterService(application.NewService(svc))`. Every method becomes a TS function in `frontend/bindings/.../services/`. See "Services rundown" below.
### Frontend (`frontend/src/`)
- `app.tsx` — top-level routes. Hash router with `/quick`, `/browser-login`, `/update`, `/session-expired`, `/settings` (own layout), and a root `AppLayout` that hosts `Main` and a `*` catch-all.
- `layouts/AppLayout.tsx` — composition shell. Wraps `Header + Outlet` in `ProfileProvider → DebugBundleProvider → ClientVersionProvider`. The wide-panel `expanded` state lives here as plain `useState` (no persistence) and is passed to `Header` via props and `Main` via Outlet context.
- `layouts/SettingsLayout.tsx` — used when the settings window opens (route `/settings`).
- `modules/*/Context.tsx` — context providers (`auto-update`, `debug-bundle`, `profile`).
- `pages/` — full-screen, single-purpose pages opened in popups or via top-level routes (`BrowserLogin`, `SessionExpired`, `Update`, `Debug`).
- `screens/` — content shown inside `AppLayout` (`Status`, `Peers`, `Networks`, `Profiles`, `Settings`, `Update`, `QuickActions`, `Debug`).
### Generated bindings
- `frontend/bindings/**` — generated, **gitignored**, do not edit by hand. After editing any `services/*.go`, regenerate from this directory via `wails3 generate bindings -clean=true -ts` (or `pnpm bindings` from `frontend/`). Fresh clones need to run this once before `pnpm typecheck` will succeed; `wails3 dev` regenerates on its own.
## Services rundown
All services live in `services/` and assume a build tag `!android && !ios && !freebsd && !js`. Each takes a shared `DaemonConn` (`conn.go`) and is registered in `main.go`.
| Service | File | Responsibility |
|---|---|---|
| `Connection` | `connection.go` | `Login` / `WaitSSOLogin` / `Up` / `Down` / `Logout` / `OpenURL`. `Up` is always async (`Async: true`); status flows back through `Peers`. `Login` Down-resets the daemon first to dislodge a stale WaitSSOLogin. `OpenURL` honors `$BROWSER`. |
| `Settings` | `settings.go` | `GetConfig` / `SetConfig` (partial update — pointer fields are sent, nil fields preserved) / `GetFeatures` (operator-disabled UI surfaces). |
| `Profiles` | `profile.go` | `Username` / `List` / `GetActive` / `Switch` / `Add` / `Remove`. `List` populates `Email` from the **user-side** state file (`profilemanager.NewProfileManager().GetProfileState`) — the daemon runs as root and can't read it. |
| `ProfileSwitcher` | `profileswitcher.go` | `SwitchActive` — the single entry point both tray and frontend should use for profile flips. Applies the reconnect policy (see "Profile switching" below), mirrors the daemon switch into the user-side `profilemanager`, drives optimistic feedback via `Peers.BeginProfileSwitch`. |
| `Peers` | `peers.go` | Daemon status snapshot + two long-running streams (`SubscribeStatus``EventStatus`, `SubscribeEvents``EventSystem`). Emits synthetic `StatusDaemonUnavailable` when the socket is unreachable. Owns the profile-switch suppression filter (`BeginProfileSwitch` / `CancelProfileSwitch` / `shouldSuppress`). Fan-outs update metadata into dedicated `EventUpdateAvailable` / `EventUpdateProgress` events. |
| `Networks` | `network.go` | `List` / `Select` / `Deselect` of routed networks. |
| `Forwarding` | `forwarding.go` | `List` exposed/forwarded services from the daemon's reverse-proxy table. |
| `Debug` | `debug.go` | `Bundle` (debug bundle creation + optional upload) / `Get|SetLogLevel` / `RevealFile` (cross-platform "show in file manager"). |
| `Update` | `update.go` | `Trigger` (enforced installer) / `GetInstallerResult` / `Quit` (used by the `/update` page after a successful install). |
| `WindowManager` | `windowmanager.go` | `OpenSettings` / `OpenBrowserLogin(uri)` / `CloseBrowserLogin`. Auxiliary windows are created on first open and **destroyed** on close (Wails-recommended singleton pattern; prevents the macOS dock-reopen from resurrecting hidden windows). |
`DaemonConn` is defined in `services/conn.go`; `ptrStr` (string-to-*string helper for proto pointer fields) lives there too.
## Daemon proto
- Proto source: `../proto/daemon.proto`. Generated Go in `../proto/*.pb.go`.
- Regen: `cd ../proto && protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative daemon.proto`
- Pinned versions (see `daemon.pb.go` header): `protoc v7.34.1`, `protoc-gen-go v1.36.6`. CI's `proto-version-check` workflow fails on mismatch.
- After proto regen, also regen Wails bindings so the TS layer picks up new fields.
## Events bus
`main.go` registers four event types so the frontend can subscribe with typed payloads:
```go
application.RegisterEvent[services.Status](services.EventStatus) // "netbird:status"
application.RegisterEvent[services.SystemEvent](services.EventSystem) // "netbird:event"
application.RegisterEvent[services.UpdateAvailable](services.EventUpdateAvailable) // "netbird:update:available"
application.RegisterEvent[services.UpdateProgress](services.EventUpdateProgress) // "netbird:update:progress"
```
Two additional plain-string events flow between Go and JS without a typed payload:
- `EventTriggerLogin = "trigger-login"` — emitted by the tray (or other Go-side triggers) to ask the frontend's `startLogin()` orchestrator to begin an SSO flow.
- `EventBrowserLoginCancel = "browser-login:cancel"` — emitted by the `BrowserLogin` window when the user clicks Cancel or closes the window (red X). `startLogin()` listens and tears down the pending daemon SSO wait.
Daemon connection status strings (`services/peers.go`) — mirror `internal.Status*` in `client/internal/state.go`:
```go
StatusConnected, StatusConnecting, StatusIdle,
StatusNeedsLogin, StatusLoginFailed, StatusSessionExpired,
StatusDaemonUnavailable // synthetic, emitted by Peers when the socket is unreachable
```
## Profile switching
`services/profileswitcher.go` is the single source of truth for the reconnect policy. Both the tray (`tray.go switchProfile`) and the frontend's `screens/Profiles.tsx` call `ProfileSwitcher.SwitchActive`; identical inputs give identical state transitions.
Reconnect policy (driven by `prevStatus` from `Peers.Get`):
| Previous status | Action | Optimistic UI | Suppressed events until new flow begins |
|---|---|---|---|
| Connected | Switch + Down + Up | Connecting (synthetic) | Connected, Idle |
| Connecting | Switch + Down + Up | Connecting (unchanged) | Connected, Idle |
| NeedsLogin / LoginFailed / SessionExpired | Switch + Down | (no change) | — |
| Idle | Switch only | (no change) | — |
Only Connected/Connecting trigger `Peers.BeginProfileSwitch`. That:
1. Sets a 30s `switchInProgress` guard.
2. Emits a synthetic `Status{Status: StatusConnecting}` so both tray and React paint immediately.
3. Tells `statusStreamLoop` to drop the daemon's stale Connected updates (peer count drops as the engine tears down) and the transient Idle in between Down and the new Up.
`shouldSuppress` releases the guard as soon as a status that signals the new flow began arrives:
- **Suppressed**: Connected, Idle
- **Pass through and clear**: Connecting / NeedsLogin / LoginFailed / SessionExpired / DaemonUnavailable
- **Timeout fallback**: 30s elapsed → clear flag, emit normally.
`Peers.CancelProfileSwitch` aborts the suppression — called by `tray.go handleDisconnect` so the user's "Disconnect while Connecting" click paints through immediately.
Also: `ProfileSwitcher.SwitchActive` mirrors the daemon switch into the user-side `profilemanager` (`~/Library/Application Support/netbird/active_profile`). The CLI's `netbird up` reads this file and sends the resolved profile name back; if it diverges from the daemon's `/var/lib/netbird/active_profile.json`, the daemon silently flips back. Mirror failures don't abort the switch — surfaced as a warning.
## Auxiliary windows (`WindowManager`)
The main window is created up front in `main.go`. Auxiliary windows are created on demand by `services.WindowManager`:
- **Settings** (`/#/settings`) — opened from the header gear icon (`layouts/Header.tsx → WindowManager.OpenSettings`). Frameless-look (translucent macOS backdrop, hidden inset title bar), fixed 900×640, no resize, no minimise/maximise.
- **BrowserLogin** (`/#/browser-login?uri=…`) — opened by the connection toggle's SSO flow (`layouts/ConnectionStatusSwitch.tsx`). 460×440, fixed size. The close button (red X) fires `EventBrowserLoginCancel` so the JS-side `startLogin()` can tear down the daemon's pending `WaitSSOLogin`. `WindowManager.CloseBrowserLogin` closes it programmatically when the flow completes.
Both windows are **destroyed** on close (mutex-guarded singleton; `closing` hook nils the field). Destroying rather than hiding is deliberate — Wails' macOS dock-reopen handler resurrects hidden windows, which we don't want for auxiliaries.
The main window is **hidden** on close (the `WindowClosing` hook calls `e.Cancel(); window.Hide()`). The user reaches "really quit" through the tray → Quit menu entry.
## Linux tray support (StatusNotifierWatcher + XEmbed)
Minimal WMs (Fluxbox, OpenBox, i3, dwm, vanilla GNOME without the AppIndicator extension) don't ship a `StatusNotifierWatcher`, so tray icons using libayatana-appindicator / freedesktop StatusNotifier silently fail. `main.go` calls `startStatusNotifierWatcher()` *before* `NewTray` so the Wails systray's `RegisterStatusNotifierItem` call hits the in-process watcher we control.
- `tray_watcher_linux.go` — owns `org.kde.StatusNotifierWatcher` on the session bus if no other process has it. Safe to call unconditionally.
- `xembed_host_linux.go` + `xembed_tray_linux.{c,h}` — when an XEmbed tray (`_NET_SYSTEM_TRAY_S0`) is available, also start an in-process XEmbed host that bridges the SNI icon into the XEmbed tray. Reads `IconPixmap` over D-Bus, draws via cairo+X11, polls for clicks, fetches `com.canonical.dbusmenu.GetLayout` for the popup menu, fires `com.canonical.dbusmenu.Event` on click.
Build is gated on `linux && !386`; the 386 build (no cgo) and non-Linux builds use the `tray_watcher_other.go` no-op.
## Wails Dialogs (frontend, `@wailsio/runtime`)
The frontend dialog API lives in `@wailsio/runtime` as `Dialogs`. Authoritative signatures are in
`frontend/node_modules/@wailsio/runtime/types/dialogs.d.ts`.
### Message dialogs
```ts
import { Dialogs } from "@wailsio/runtime";
await Dialogs.Info({ Title, Message, Buttons?, Detached? });
await Dialogs.Warning({ Title, Message, Buttons?, Detached? });
await Dialogs.Error({ Title, Message, Buttons?, Detached? });
await Dialogs.Question({ Title, Message, Buttons?, Detached? });
```
All four return `Promise<string>` resolving to the **Label** of the button the user clicked. With no `Buttons` provided you get a single OK button — the promise just resolves when the user dismisses.
`MessageDialogOptions` fields:
- `Title?: string` — window title (short).
- `Message?: string` — the body text.
- `Buttons?: Button[]` — custom buttons. Each `Button` is `{ Label?, IsCancel?, IsDefault? }`. `IsCancel` is what Esc/⌘. triggers; `IsDefault` is what Enter triggers.
- `Detached?: boolean` — when `true`, the dialog isn't tied to the parent window (no sheet behavior on macOS).
### File dialogs
`Dialogs.OpenFile(options)` and `Dialogs.SaveFile(options)` — see `dialogs.d.ts` for the full `OpenFileDialogOptions` / `SaveFileDialogOptions` field set (filters, ButtonText, multi-select, hidden files, alias resolution, directory mode, etc).
### Per-OS behavior
| Platform | Behavior |
|---|---|
| **macOS** | Sheet-style when attached to a parent window. Up to ~4 custom buttons render naturally. Keyboard: Enter = default, ⌘. or Esc = cancel. Follows system theme. Accessibility is built-in. |
| **Windows** | Modal `TaskDialog`-style. Standard button labels are nudged toward OS conventions. Keyboard: Enter = default, Esc = cancel. Follows system theme. |
| **Linux** | GTK dialogs — appearance varies by desktop environment (GNOME/KDE). Follows desktop theme. Standard keyboard nav. |
Behavioural notes that affect us:
- The promise resolves with the **button label string**, not an index. Compare against the literal `Label` you passed (e.g. `if (result !== "Delete") return;`).
- `Buttons[]` on Linux/Windows uses the labels you supply, but the OS layout/styling is fixed.
- `Dialogs.Error` plays the platform error sound and uses the platform error icon. Don't use it for confirmations — use `Dialogs.Warning` or `Dialogs.Question`.
- Don't fire dialogs in a tight loop or from every keystroke — they interrupt focus and (on macOS) animate in/out. Debounce or guard with a `busy` flag.
### Frameless / custom-window dialogs (Go side)
When the native dialog API isn't enough — rich content, embedded webview, multi-screen flow — open a regular Wails window. This is done on the **Go side** via `app.Window.NewWithOptions(application.WebviewWindowOptions{...})`. Useful options:
- `Parent` — attach to a parent so OS treats it as a child.
- `AlwaysOnTop: true` — float above the parent.
- `Frameless: true` — no titlebar/chrome.
- `Resizable: false` (also `DisableResize: true` in v3) — fixed-size dialog feel.
- `Hidden: true` initially, then `dialog.Show()` + `dialog.SetFocus()`.
We **do** use this pattern, but pragmatically: `WindowManager.OpenSettings` and `OpenBrowserLogin` are regular small webview windows (not modal sheets) with no resize, hidden minimise/maximise buttons, and a translucent macOS title bar. They're not classic "OS modal dialogs"; they're just lightweight ancillary windows that look the part. Modal behaviour (`parent.SetEnabled(false)`) is intentionally not used — the user can still click back to the main window.
In-app modals (`NewProfileDialog`, delete-profile confirmation, etc.) are Radix `Dialog` primitives inside the main webview. Reach for a custom OS window only when content must escape the main window (BrowserLogin is the canonical example — its lifecycle is tied to the SSO wait) or when the window needs its own taskbar entry / dock icon.
## Conventions in this codebase
### Errors → native dialogs
We surface user-actionable errors via `Dialogs.Error` rather than red inline text. This started with the profile selector and applies broadly to operation failures (config save, profile switch, debug bundle, update, etc.).
Pattern:
```ts
try {
await SomeSvc.Operation(...);
} catch (e) {
await Dialogs.Error({
Title: "Operation Failed", // short, action-named
Message: e instanceof Error ? e.message : String(e),
});
}
```
Title rules:
- Action-named, short: "Switch Profile Failed", "Save Settings Failed", "Debug Bundle Failed".
- Not "Error" / "Something went wrong" — the dialog already says that visually.
When **not** to use a native dialog:
- **Form validation** (`Input.tsx`, URL-format checks, etc.) — inline next to the field. Native dialogs are too heavy for keystroke-driven feedback.
- **Status/result chrome on a dedicated screen** — e.g. the `/update` and `/login` pages can show a brief "Update failed" header *in addition to* the dialog, so the screen isn't blank after dismissal.
- **Transient link errors on the dashboard** (e.g. `link.error` on a management/signal card) — these flap in/out as the daemon recovers; an inline indicator is more appropriate than a dialog.
- **Result notifications inside a success flow** — e.g. "bundle saved but upload failed" can stay inline since the operation otherwise succeeded.
### Confirmations
Use `Dialogs.Warning` with explicit `Buttons`:
```ts
const r = await Dialogs.Warning({
Title: "Delete Profile",
Message: `Are you sure you want to delete "${name}"?`,
Buttons: [
{ Label: "Cancel", IsCancel: true },
{ Label: "Delete", IsDefault: true },
],
});
if (r !== "Delete") return;
```
Compare against the **Label string** returned, not an index.
### OS notifications
The tray uses Wails' built-in `notifications` service (`github.com/wailsapp/wails/v3/pkg/services/notifications`). One `notifications.NotificationService` is created in `main.go` and passed into `TrayServices.Notifier`. Notification IDs are prefixed for coalescing (`netbird-update-<version>`, `netbird-event-<id>`, `netbird-tray-error`, `netbird-session-expired`).
OS notifications are gated by the user's "Notifications" toggle (cached in `Tray.notificationsEnabled`, seeded from `Settings.GetConfig` at boot). `Severity == "critical"` events bypass the gate, mirroring the legacy Fyne event.Manager.
### Bindings & types
Always import generated bindings from `@bindings/services` and types from `@bindings/services/models.js`. The path alias is set up in `tsconfig.json` / `vite.config.ts`.
After editing any `services/*.go` (or the underlying proto), regenerate:
```
wails3 generate bindings -clean=true -ts
```
### Profile context
`modules/profile/ProfileContext.tsx` is the React-side source of truth for `username`, `activeProfile`, and the `profiles` list. It exposes `switchProfile`, `addProfile`, `removeProfile`, `logoutProfile`, and `refresh`.
Two important nuances:
1. **Two switch paths exist.** `screens/Profiles.tsx` calls `ProfileSwitcher.SwitchActive` (the Go-side single-source-of-truth path that also drives the optimistic-Connecting paint and the Peers suppression filter). `ProfileContext.switchProfile`, used elsewhere, still implements the reconnect policy in TS: it calls `Profiles.Switch` and, only if the daemon was actively online, follows up with `Connection.Down` + `Connection.Up`. The TS path skips `Peers.BeginProfileSwitch` so it won't paint optimistic Connecting through the tray. Prefer `ProfileSwitcher.SwitchActive` for new call sites.
2. **Don't call `Connection.Up` on an Idle/NeedsLogin daemon.** The daemon's internal 50s `waitForUp` will block until `DeadlineExceeded`. Both switch paths gate `Up` on a previously-online status (Connected/Connecting). Callers should not bring the connection up themselves outside this flow — `Connection.Up` is reserved for the explicit Connect button and the post-switch resume.
## Build / dev tasks
- `task dev` — Wails dev mode (live reload).
- `task build` — production build for the current OS (`Taskfile.yml` dispatches to `build/{darwin,linux,windows}/Taskfile.yml`).
- `task build:server` / `task run:server` / `task build:docker` / `task run:docker` — server-mode (HTTP, no GUI) variants. See `build/Taskfile.yml`.
- `task generate:bindings` does **not** exist as a top-level alias — run `wails3 generate bindings -clean=true -ts` directly from this directory.
CLI flags (parsed in `main.go`):
- `--daemon-addr <addr>` — gRPC address, default per `DaemonAddr()` (Unix socket on Linux/macOS, `tcp://127.0.0.1:41731` on Windows).
- `--log-file <target>` — repeatable. Each value is `console`, `syslog`, or a file path. First user-provided value drops the seeded `console` default.
- `--log-level <level>``trace|debug|info|warn|error` (default `info`).
## Useful references
- Wails v3 dialog docs: https://v3.wails.io/features/dialogs/message/ and https://v3.wails.io/features/dialogs/custom/ (may 403 from some clients).
- Wails v3 multiple-windows guidance: https://v3.wails.io/learn/multiple-windows/
- Authoritative TS signatures: `frontend/node_modules/@wailsio/runtime/types/dialogs.d.ts`.
- Wails examples: https://github.com/wailsapp/wails/tree/master/v3/examples/dialogs

BIN
client/ui/Netbird.icns Normal file

Binary file not shown.

View File

@@ -1,100 +0,0 @@
# NetBird desktop UI (Wails3 + React)
Replaces `client/ui` (Fyne). One binary on Windows / macOS / Linux,
talks to the NetBird daemon over gRPC, renders a React frontend in a
WebView.
## Prerequisites
- Go ≥ 1.25, Node ≥ 20, **pnpm** (`corepack enable && corepack prepare pnpm@latest --activate`)
- `wails3` CLI: `go install github.com/wailsapp/wails/v3/cmd/wails3@latest`
- `task`: `go install github.com/go-task/task/v3/cmd/task@latest`
- A running NetBird daemon (default: `unix:///var/run/netbird.sock`,
Windows `tcp://127.0.0.1:41731`)
- Linux only: `libwebkit2gtk-4.1-dev`, `libgtk-3-dev`,
`libayatana-appindicator3-dev`
## Develop without rebuilding
```bash
cd client/ui
task dev
```
`task dev` runs Vite (port 9245) + the Go binary + a `*.go` watcher.
Frontend edits hot-reload instantly. Go edits trigger a rebuild and
relaunch. Pass daemon flags after `--`:
```bash
task dev -- --daemon-addr=tcp://127.0.0.1:41731
```
For pure UI work (no native window, fastest loop):
```bash
cd frontend && pnpm dev
```
## Production build
```bash
task build
```
Output in `bin/`. Frontend assets are embedded into the binary.
### Cross-compile Windows from Linux
Install the mingw-w64 toolchain once:
```bash
sudo apt install gcc-mingw-w64-x86-64 # Debian/Ubuntu
sudo dnf install mingw64-gcc # Fedora
sudo pacman -S mingw-w64-gcc # Arch
```
Then:
```bash
CGO_ENABLED=1 task windows:build
```
Produces `bin/netbird-ui.exe`. macOS cross-compile from Linux is not
supported (signing and notarization need a real Mac).
### Windows console build (logs in the terminal)
Default `windows:build` links the binary as a Windows GUI app, which
detaches from the launching console — `logrus` output, `fmt.Println`,
and panics go nowhere visible. To debug tray/event/daemon issues:
```bash
CGO_ENABLED=1 task windows:build:console
```
Produces `bin/netbird-ui-console.exe`. Run it from `cmd.exe` /
PowerShell / Windows Terminal and stdout/stderr land in that
terminal. Same flag works on a native Windows build (drop the
`CGO_ENABLED=1` if your toolchain already has it set).
## Regenerating bindings
When a Go service signature changes:
```bash
wails3 generate bindings
```
`task dev` does this automatically on `*.go` save.
## Tray icons
Source SVGs live in `assets/svg/` (state.svg + state-macos.svg). After editing
any SVG, rasterize to the PNGs the Go side embeds:
```bash
task common:generate:tray:icons
```
Requires Inkscape. Commit the resulting `assets/*.png` files alongside the
SVG change so CI doesn't need Inkscape installed.

View File

@@ -1,58 +0,0 @@
version: '3'
includes:
common: ./build/Taskfile.yml
windows: ./build/windows/Taskfile.yml
darwin: ./build/darwin/Taskfile.yml
linux: ./build/linux/Taskfile.yml
vars:
APP_NAME: "netbird-ui"
BIN_DIR: "bin"
VITE_PORT: '{{.WAILS_VITE_PORT | default 9245}}'
tasks:
build:
summary: Builds the application
cmds:
- task: "{{OS}}:build"
package:
summary: Packages a production build of the application
cmds:
- task: "{{OS}}:package"
run:
summary: Runs the application
cmds:
- task: "{{OS}}:run"
dev:
summary: Runs the application in development mode
cmds:
- wails3 dev -config ./build/config.yml -port {{.VITE_PORT}}
setup:docker:
summary: Builds Docker image for cross-compilation (~800MB download)
cmds:
- task: common:setup:docker
build:server:
summary: Builds the application in server mode (no GUI, HTTP server only)
cmds:
- task: common:build:server
run:server:
summary: Runs the application in server mode
cmds:
- task: common:run:server
build:docker:
summary: Builds a Docker image for server mode deployment
cmds:
- task: common:build:docker
run:docker:
summary: Builds and runs the Docker image
cmds:
- task: common:run:docker

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 452 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 452 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 433 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 483 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 475 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 456 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.6 KiB

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.6 KiB

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.5 KiB

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.3 KiB

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

View File

@@ -1,10 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="none">
<g transform="translate(0.5 4.5) scale(1.0)" opacity="0.7">
<path d="M21.4631 0.523438C17.8173 0.857913 16.0028 2.95675 15.3171 4.01871L4.66406 22.4734H17.5163L30.1929 0.523438H21.4631Z" fill="#F68330"/>
<path d="M17.5265 22.4737L0 3.88525C0 3.88525 19.8177 -1.44128 21.7493 15.1738L17.5265 22.4737Z" fill="#F68330"/>
<path d="M14.9236 4.70563L9.54688 14.0208L17.5158 22.4747L21.7385 15.158C21.0696 9.44682 18.2851 6.32784 14.9236 4.69727" fill="#F05252"/>
</g>
<circle cx="25" cy="25" r="7" fill="white"/>
<path d="M 22.6 24.5 v -1.4 a 2.4 2.4 0 0 1 4.8 0 v 1.4" fill="none" stroke="#D97706" stroke-width="1.5" stroke-linecap="round"/>
<rect x="21.6" y="24.5" width="6.8" height="4.9" rx="0.9" fill="none" stroke="#D97706" stroke-width="1.5"/>
</svg>

Before

Width:  |  Height:  |  Size: 847 B

Some files were not shown because too many files have changed in this diff Show More