mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-02 22:19:54 +00:00
Compare commits
1 Commits
ui-tray-li
...
peer-acl-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38603c7552 |
11
.github/workflows/golang-test-darwin.yml
vendored
11
.github/workflows/golang-test-darwin.yml
vendored
@@ -45,13 +45,4 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the grep then drops the broken package by path. Without -e,
|
||||
# go list aborts with empty stdout and `go test` falls back to the repo
|
||||
# root, which has no Go files.
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
|
||||
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
16
.github/workflows/golang-test-linux.yml
vendored
16
.github/workflows/golang-test-linux.yml
vendored
@@ -53,7 +53,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
@@ -145,7 +145,7 @@ jobs:
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
@@ -158,15 +158,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the grep then drops the broken package by path. Without -e,
|
||||
# go list aborts with empty stdout and `go test` falls back to the repo
|
||||
# root, which has no Go files.
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags "devcert integration" -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -228,7 +220,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -e -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags "devcert integration" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
|
||||
9
.github/workflows/golang-test-windows.yml
vendored
9
.github/workflows/golang-test-windows.yml
vendored
@@ -65,15 +65,8 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- name: Generate test script
|
||||
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
|
||||
# which fails to compile until the frontend has been built. The Wails UI
|
||||
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
|
||||
# before goreleaser.
|
||||
# `go list -e` lets the listing succeed even though the embed fails to
|
||||
# resolve; the Where-Object pipeline then drops the broken package by
|
||||
# path. Without -e, go list aborts with empty stdout.
|
||||
run: |
|
||||
$packages = go list -e ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } | Where-Object { $_ -notmatch '/client/ui' }
|
||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||
|
||||
17
.github/workflows/golangci-lint.yml
vendored
17
.github/workflows/golangci-lint.yml
vendored
@@ -22,11 +22,7 @@ jobs:
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
# Non-English UI translations trip codespell on real foreign words
|
||||
# (de: "Sie", "oder", "ist"). Only en/common.json is the source of
|
||||
# truth that should be spell-checked. Add each new locale dir here
|
||||
# when a language is added under client/ui/i18n/locales/.
|
||||
skip: go.mod,go.sum,**/proxy/web/**,**/pnpm-lock.yaml,**/package-lock.json,client/ui/i18n/locales/de/**,client/ui/i18n/locales/hu/**
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -58,16 +54,7 @@ jobs:
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Stub Wails frontend bundle
|
||||
# client/ui/main.go has //go:embed all:frontend/dist. The
|
||||
# directory is produced by `pnpm run build` and is gitignored, so
|
||||
# lint-only runs (no frontend toolchain) need a placeholder file
|
||||
# for the embed pattern to match.
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p client/ui/frontend/dist
|
||||
touch client/ui/frontend/dist/.embed-placeholder
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
|
||||
79
.github/workflows/release.yml
vendored
79
.github/workflows/release.yml
vendored
@@ -194,9 +194,9 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
@@ -356,18 +356,8 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
|
||||
- name: Set up pnpm
|
||||
uses: pnpm/action-setup@v3
|
||||
with:
|
||||
version: 11
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev gcc-mingw-w64-x86-64
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
|
||||
- name: Decode GPG signing key
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
@@ -386,16 +376,10 @@ jobs:
|
||||
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the binding generator always matches
|
||||
# the wails runtime the binary links against.
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
@@ -463,20 +447,6 @@ jobs:
|
||||
run: go mod tidy
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
- name: Set up pnpm
|
||||
uses: pnpm/action-setup@v3
|
||||
with:
|
||||
version: 11
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the binding generator always matches
|
||||
# the wails runtime the binary links against.
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
@@ -564,6 +534,23 @@ jobs:
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
|
||||
destination: ${{ env.downloadPath }}\mesa3d.7z
|
||||
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
|
||||
|
||||
- name: Move opengl32.dll into dist (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
@@ -586,28 +573,6 @@ jobs:
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Set up Go for wails3 CLI
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Install wails3 CLI
|
||||
# Version derived from go.mod so the bootstrapper payload always
|
||||
# matches the wails runtime the binary links against.
|
||||
shell: bash
|
||||
run: |
|
||||
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
|
||||
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
|
||||
|
||||
- name: Stage WebView2 bootstrapper for installers
|
||||
# Both client/installer.nsis and client/netbird.wxs reference
|
||||
# client/MicrosoftEdgeWebview2Setup.exe. wails3 writes it there.
|
||||
# The signing pipeline (netbirdio/sign-pipelines) does the same
|
||||
# step for release builds; this mirrors it for PR sanity testing.
|
||||
shell: bash
|
||||
run: wails3 generate webview2bootstrapper -dir client
|
||||
|
||||
- name: Build NSIS installer
|
||||
shell: pwsh
|
||||
env:
|
||||
|
||||
2
.github/workflows/wasm-build-validation.yml
vendored
2
.github/workflows/wasm-build-validation.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
|
||||
@@ -114,16 +114,6 @@ linters:
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1012"
|
||||
# client/ui/main.go uses //go:embed all:frontend/dist; the
|
||||
# directory is populated by `pnpm build` in the release pipeline
|
||||
# and missing at lint time, so the embed parses to "no matching
|
||||
# files found" — surfaced by golangci-lint's typecheck pre-pass.
|
||||
# Suppress just that one diagnostic; the rest of the package
|
||||
# (services/, tray.go, grpc.go, ...) still gets linted normally.
|
||||
- linters:
|
||||
- typecheck
|
||||
path: client/ui/main\.go
|
||||
text: "pattern all:frontend/dist"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
|
||||
before:
|
||||
hooks:
|
||||
# Bindings are gitignored; regenerate before the frontend build so
|
||||
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
|
||||
# build without them).
|
||||
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
|
||||
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
|
||||
|
||||
builds:
|
||||
- id: netbird-ui
|
||||
dir: client/ui
|
||||
@@ -79,15 +70,12 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/build/linux/netbird.desktop
|
||||
- src: client/ui/build/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/build/appicon.png
|
||||
- src: client/ui/assets/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
- libgtk-3-0
|
||||
- libwebkit2gtk-4.1-0
|
||||
- libayatana-appindicator3-1
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
@@ -101,15 +89,12 @@ nfpms:
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/build/linux/netbird.desktop
|
||||
- src: client/ui/build/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/build/appicon.png
|
||||
- src: client/ui/assets/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
- gtk3
|
||||
- webkit2gtk4.1
|
||||
- libayatana-appindicator-gtk3
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
|
||||
before:
|
||||
hooks:
|
||||
# Bindings are gitignored; regenerate before the frontend build so
|
||||
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
|
||||
# build without them).
|
||||
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
|
||||
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
|
||||
|
||||
builds:
|
||||
- id: netbird-ui-darwin
|
||||
dir: client/ui
|
||||
@@ -29,6 +20,8 @@ builds:
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird-ui-darwin
|
||||
|
||||
@@ -22,19 +22,11 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// extendSessionFlag drives the `netbird login --extend` flow: refresh the
|
||||
// SSO session expiry on the management server without tearing down the
|
||||
// tunnel. Mutually exclusive with setup-key login (a setup-key cannot
|
||||
// refresh an SSO-tracked peer — see auth.errSetupKeyOnSSOExpiredPeer).
|
||||
var extendSessionFlag bool
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
loginCmd.PersistentFlags().BoolVar(&extendSessionFlag, "extend", false,
|
||||
"refresh the SSO session expiry without tearing down the tunnel (requires an active connection)")
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
@@ -69,16 +61,6 @@ var loginCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
if extendSessionFlag {
|
||||
if providedSetupKey != "" {
|
||||
return fmt.Errorf("--extend cannot be combined with a setup key; setup keys can only enrol new peers")
|
||||
}
|
||||
if err := doExtendSession(ctx, cmd); err != nil {
|
||||
return fmt.Errorf("extend session failed: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if util.FindFirstLogPath(logFiles) == "" {
|
||||
if err := doForegroundLogin(ctx, cmd, providedSetupKey, activeProf); err != nil {
|
||||
@@ -168,65 +150,6 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
return nil
|
||||
}
|
||||
|
||||
// doExtendSession drives the daemon's RequestExtendAuthSession /
|
||||
// WaitExtendAuthSession pair. The user is sent through a regular SSO flow
|
||||
// (browser + verification URL) and the resulting JWT is forwarded to the
|
||||
// management server's ExtendAuthSession RPC. The tunnel stays up
|
||||
// throughout — no Down/Up, no network-map resync.
|
||||
func doExtendSession(ctx context.Context, cmd *cobra.Command) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
req := &proto.RequestExtendAuthSessionRequest{}
|
||||
// Pre-fill the IdP login hint from the active profile so the user
|
||||
// doesn't have to retype their email. Best-effort: we still proceed
|
||||
// without a hint if the lookup fails.
|
||||
pm := profilemanager.NewProfileManager()
|
||||
if active, perr := pm.GetActiveProfile(); perr == nil {
|
||||
if profState, sperr := pm.GetProfileState(active.Name); sperr == nil && profState.Email != "" {
|
||||
req.Hint = &profState.Email
|
||||
}
|
||||
}
|
||||
|
||||
startResp, err := client.RequestExtendAuthSession(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("start extend session: %v", err)
|
||||
}
|
||||
|
||||
uri := startResp.GetVerificationURIComplete()
|
||||
if uri == "" {
|
||||
uri = startResp.GetVerificationURI()
|
||||
}
|
||||
openURL(cmd, uri, startResp.GetUserCode(), noBrowser, showQR)
|
||||
|
||||
waitResp, err := client.WaitExtendAuthSession(ctx, &proto.WaitExtendAuthSessionRequest{
|
||||
DeviceCode: startResp.GetDeviceCode(),
|
||||
UserCode: startResp.GetUserCode(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("wait for extend session: %v", err)
|
||||
}
|
||||
|
||||
if ts := waitResp.GetSessionExpiresAt(); ts.IsValid() && !ts.AsTime().IsZero() {
|
||||
deadline := ts.AsTime().Local()
|
||||
cmd.Printf("Session extended. New expiry: %s\n", deadline.Format("2006-01-02 15:04:05 MST"))
|
||||
} else {
|
||||
// Management reported the peer is not eligible (e.g. login
|
||||
// expiration disabled on the account). Surface that fact
|
||||
// instead of pretending the call succeeded.
|
||||
cmd.Println("Session extension call completed, but the management server did not return a new deadline (peer may not be SSO-tracked or login expiration is disabled).")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) (*profilemanager.Profile, error) {
|
||||
// switch profile if provided
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
@@ -118,11 +117,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var sessionExpiresAt time.Time
|
||||
if ts := resp.GetSessionExpiresAt(); ts.IsValid() {
|
||||
sessionExpiresAt = ts.AsTime().UTC()
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
@@ -133,7 +127,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
IPsFilter: ipsFilterMap,
|
||||
ConnectionTypeFilter: connectionTypeFilter,
|
||||
ProfileName: profName,
|
||||
SessionExpiresAt: sessionExpiresAt,
|
||||
})
|
||||
var statusOutputString string
|
||||
switch {
|
||||
|
||||
@@ -464,7 +464,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
_ = engine.RunHealthProbes(context.Background(), false)
|
||||
_ = engine.RunHealthProbes(false)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
fm, err := uspfilter.Create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -29,47 +29,80 @@ const (
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// SkipNftablesEnv is the environment variable to skip nftables check
|
||||
const SkipNftablesEnv = "NB_SKIP_NFTABLES_CHECK"
|
||||
|
||||
// errNoFirewallManager indicates no kernel firewall backend is present,
|
||||
// as opposed to a backend that exists but failed to create or initialize.
|
||||
var errNoFirewallManager = errors.New("no firewall manager found")
|
||||
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
||||
// We run in userspace mode and force userspace firewall was requested.
|
||||
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
||||
nativeFw, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding without it", err)
|
||||
}
|
||||
|
||||
log.Info("forcing userspace firewall")
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
return createUserspaceFirewall(iface, nativeFw, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
|
||||
// Kernel cannot fall back to anything else, need to return error
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
}
|
||||
|
||||
// Fall back to the userspace packet filter if native is unavailable
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
switch {
|
||||
case err == nil && !iface.IsUserspaceBind():
|
||||
// Nothing to do, fall through
|
||||
case err == nil && iface.IsUserspaceBind():
|
||||
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||
// needs a device filter for DNS interception hooks. Install a minimal
|
||||
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||
}
|
||||
case err != nil && !iface.IsUserspaceBind():
|
||||
// Kernel cannot fall back to anything else, need to return error
|
||||
return nil, err
|
||||
case err != nil && iface.IsUserspaceBind():
|
||||
// Fall back to the userspace packet filter if native is unavailable
|
||||
logNativeFirewallUnavailable(err)
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||
// needs a device filter for DNS interception hooks. Install a minimal
|
||||
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// createUserspaceFirewall builds the userspace packet filter, optionally
|
||||
// backed by a native firewall, and allows netbird interface traffic.
|
||||
func createUserspaceFirewall(iface IFaceMapper, nativeFw firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||
fm, err := uspfilter.Create(iface, nativeFw, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// logNativeFirewallUnavailable logs the fallback to userspace at info level
|
||||
// when no kernel firewall backend exists, and at warn level otherwise.
|
||||
func logNativeFirewallUnavailable(err error) {
|
||||
if errors.Is(err, errNoFirewallManager) {
|
||||
log.Infof("no native firewall backend available: %v. Proceeding with userspace", err)
|
||||
} else {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
return nil, fmt.Errorf("create firewall: %w", err)
|
||||
}
|
||||
|
||||
if err = fm.Init(stateManager); err != nil {
|
||||
@@ -88,29 +121,10 @@ func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
|
||||
log.Info("creating an nftables firewall manager")
|
||||
return nbnftables.Create(iface, mtu)
|
||||
default:
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
return nil, errors.New("no firewall manager found")
|
||||
return nil, errNoFirewallManager
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() FWType {
|
||||
useIPTABLES := false
|
||||
@@ -132,35 +146,38 @@ func check() FWType {
|
||||
}
|
||||
}
|
||||
|
||||
nf := nftables.Conn{}
|
||||
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
if !useIPTABLES {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
// search for chains where table is filter
|
||||
// if we find one, we assume that nftables manager can be used with iptables
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name == "filter" {
|
||||
// Honor the skip env before probing nftables at all.
|
||||
if os.Getenv(SkipNftablesEnv) != "true" {
|
||||
nf := nftables.Conn{}
|
||||
if chains, err := nf.ListChains(); err == nil {
|
||||
if !useIPTABLES {
|
||||
return NFTABLES
|
||||
}
|
||||
}
|
||||
|
||||
// check tables for the following constraints:
|
||||
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
||||
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
||||
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
||||
// 4. if we find an error we log and continue with iptables check
|
||||
nbTablesList, err := nf.ListTables()
|
||||
switch {
|
||||
case err == nil && len(iptablesChains) > 0:
|
||||
return IPTABLES
|
||||
case err == nil && len(nbTablesList) != 1:
|
||||
return NFTABLES
|
||||
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
||||
return IPTABLES
|
||||
case err != nil:
|
||||
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
||||
// search for chains where table is filter
|
||||
// if we find one, we assume that nftables manager can be used with iptables
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name == "filter" {
|
||||
return NFTABLES
|
||||
}
|
||||
}
|
||||
|
||||
// check tables for the following constraints:
|
||||
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
||||
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
||||
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
||||
// 4. if we find an error we log and continue with iptables check
|
||||
nbTablesList, err := nf.ListTables()
|
||||
switch {
|
||||
case err == nil && len(iptablesChains) > 0:
|
||||
return IPTABLES
|
||||
case err == nil && len(nbTablesList) != 1:
|
||||
return NFTABLES
|
||||
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
||||
return IPTABLES
|
||||
case err != nil:
|
||||
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,554 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
||||
// external DNAT from bypassing ACL rules.
|
||||
mangleFwdKey = "MANGLE-FORWARD"
|
||||
)
|
||||
|
||||
type aclEntries map[string][][]string
|
||||
|
||||
type entry struct {
|
||||
spec []string
|
||||
position int
|
||||
}
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
return &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
m.stateManager = stateManager
|
||||
|
||||
m.seedInitialEntries()
|
||||
m.seedInitialOptionalEntries()
|
||||
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
if err := m.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default chains: %w", err)
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
chain := chainNameInputRules
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||
if m.v6 && ipsetName != "" {
|
||||
ipsetName += "-v6"
|
||||
}
|
||||
proto := protoForFamily(protocol, m.v6)
|
||||
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
||||
|
||||
mangleSpecs := slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
)
|
||||
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
ipList.addIP(ip.String())
|
||||
return []firewall.Rule{&Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
specs: specs,
|
||||
v6: m.v6,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := m.createIPSet(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("create ipset: %w", err)
|
||||
}
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||
}
|
||||
|
||||
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
if action == firewall.ActionDrop {
|
||||
// Insert at the beginning of the chain (position 1)
|
||||
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
|
||||
} else {
|
||||
err = m.iptablesClient.Append(tableFilter, chain, specs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to add mangle rule: %v", err)
|
||||
mangleSpecs = nil
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
mangleSpecs: mangleSpecs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
v6: m.v6,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
shouldDestroyIpset := false
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
ip := net.ParseIP(r.ip)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parse IP %s", r.ip)
|
||||
}
|
||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(ipsetList.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
shouldDestroyIpset = true
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||
}
|
||||
|
||||
if r.mangleSpecs != nil {
|
||||
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDestroyIpset {
|
||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy empty ipset: %v", err)
|
||||
} else {
|
||||
log.Errorf("destroy empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) Reset() error {
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["INPUT"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries["FORWARD"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries[mangleFwdKey] {
|
||||
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) createDefaultChains() error {
|
||||
// chain netbird-acl-input-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
// mangle FORWARD guard rules are handled separately below
|
||||
if chainName == mangleFwdKey {
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for chainName, entries := range m.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(m.optionalEntries)
|
||||
|
||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||
for _, rule := range m.entries[mangleFwdKey] {
|
||||
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
||||
// We want to make sure our traffic is not dropped by existing rules.
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||
|
||||
// Inbound is handled by our ACLs, the rest is dropped.
|
||||
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||
|
||||
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
||||
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
||||
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
||||
// ACL mark check where it cannot be overridden.
|
||||
m.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||
"-j", "ACCEPT",
|
||||
})
|
||||
m.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "DNAT",
|
||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
"-j", "DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
||||
func (m *aclManager) updateState() {
|
||||
if m.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := m.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
specs = append(specs, applyPort("--sport", sPort)...)
|
||||
specs = append(specs, applyPort("--dport", dPort)...)
|
||||
return specs
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
|
||||
if ipsetName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
actionSuffix := ""
|
||||
if action == firewall.ActionDrop {
|
||||
actionSuffix = "-drop"
|
||||
}
|
||||
|
||||
switch {
|
||||
case sPort != nil && dPort != nil:
|
||||
return ipsetName + "-sport-dport" + actionSuffix
|
||||
case sPort != nil:
|
||||
return ipsetName + "-sport" + actionSuffix
|
||||
case dPort != nil:
|
||||
return ipsetName + "-dport" + actionSuffix
|
||||
default:
|
||||
return ipsetName + actionSuffix
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if m.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
}
|
||||
|
||||
if err := ipset.Del(name, entry); err != nil {
|
||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) flushIPSet(name string) error {
|
||||
return ipset.Flush(name)
|
||||
}
|
||||
|
||||
func (m *aclManager) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
352
client/firewall/iptables/chains_linux.go
Normal file
352
client/firewall/iptables/chains_linux.go
Normal file
@@ -0,0 +1,352 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) createContainers() error {
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFwdIn, tableFilter},
|
||||
{chainRTFwdOut, tableFilter},
|
||||
{chainRTPre, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRdr, tableNat},
|
||||
{chainRTMSSClamp, tableMangle},
|
||||
} {
|
||||
// Fallback: clear chains that survived an unclean shutdown.
|
||||
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFwdIn); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFwdOut); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add static nat rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addJumpRules(); err != nil {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addJumpRules() error {
|
||||
// Jump to nat chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPostrouting, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat postrouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNATPost] = natRule
|
||||
|
||||
// Jump to mangle prerouting chain
|
||||
preRule := []string{"-j", chainRTPre}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainPrerouting, 1, preRule...); err != nil {
|
||||
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpManglePre] = preRule
|
||||
|
||||
// Jump to nat prerouting chain
|
||||
rdrRule := []string{"-j", chainRTRdr}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPrerouting, 1, rdrRule...); err != nil {
|
||||
return fmt.Errorf("add nat prerouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNATPre] = rdrRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) setupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
preRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPrerouting, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePre] = preRule
|
||||
}
|
||||
|
||||
postRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPostrouting, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePost] = postRule
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map. Rules are
|
||||
// inserted at position 1, so the order here is reversed.
|
||||
//
|
||||
// Existing FORWARD policy decides outbound traffic towards our
|
||||
// interface. If FORWARD policy is "drop", we add an
|
||||
// established/related rule to allow return traffic for inbound rules.
|
||||
func (r *family) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
||||
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", chainACLInput})
|
||||
r.appendToEntries(chainInput, append([]string{"-i", r.wgIface.Name()}, established...))
|
||||
|
||||
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
||||
r.appendToEntries(chainForward, []string{"-o", r.wgIface.Name(), "-j", chainRTFwdOut})
|
||||
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", chainRTFwdIn})
|
||||
|
||||
// Mangle FORWARD guard: when external DNAT redirects traffic from
|
||||
// the wg interface, it traverses FORWARD instead of INPUT,
|
||||
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
|
||||
// inserted above ours. Mangle runs before filter, so these guard
|
||||
// rules enforce the ACL mark check where it cannot be overridden.
|
||||
r.appendToEntries(mangleForwardKey, []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||
"-j", "ACCEPT",
|
||||
})
|
||||
r.appendToEntries(mangleForwardKey, []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "DNAT",
|
||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
"-j", "DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) seedInitialOptionalEntries() {
|
||||
r.optionalEntries[chainForward] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
|
||||
r.entries[chain] = append(r.entries[chain], spec)
|
||||
}
|
||||
|
||||
func (r *family) createDefaultChains() error {
|
||||
if err := r.iptablesClient.NewChain(tableName, chainACLInput); err != nil {
|
||||
return fmt.Errorf("create %s chain: %w", chainACLInput, err)
|
||||
}
|
||||
|
||||
for chain, rules := range r.entries {
|
||||
// mangle FORWARD guard rules are handled separately below
|
||||
if chain == mangleForwardKey {
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := r.iptablesClient.InsertUnique(tableName, string(chain), 1, rule...); err != nil {
|
||||
return fmt.Errorf("insert jump rule into %s: %w", chain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for chain, entries := range r.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := r.iptablesClient.InsertUnique(tableName, string(chain), entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
r.entries[chain] = append(r.entries[chain], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(r.optionalEntries)
|
||||
|
||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||
for _, rule := range r.entries[mangleForwardKey] {
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainForward, rule...); err != nil {
|
||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) cleanUpDefaultForwardRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
// cleanJumpRules removes the OUTPUT jump to NETBIRD-NAT-OUTPUT among
|
||||
// the others, so the chain below deletes cleanly instead of failing
|
||||
// with "device or resource busy".
|
||||
if err := r.cleanJumpRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("clean jump rules: %w", err))
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFwdIn, tableFilter},
|
||||
{chainRTFwdOut, tableFilter},
|
||||
{chainRTPre, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRdr, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSClamp, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
|
||||
continue
|
||||
}
|
||||
if ok {
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanJumpRules() error {
|
||||
// locations maps each tracked jump rule to the built-in table and
|
||||
// chain it was inserted into.
|
||||
locations := map[firewall.RuleID]struct{ table, chain string }{
|
||||
jumpNATPost: {tableNat, chainPostrouting},
|
||||
jumpManglePre: {tableMangle, chainPrerouting},
|
||||
jumpNATPre: {tableNat, chainPrerouting},
|
||||
jumpMSSClamp: {tableMangle, chainForward},
|
||||
jumpNATOutput: {tableNat, chainOutput},
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for ruleID, loc := range locations {
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(loc.table, loc.chain, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete rule from chain %s in table %s: %w", loc.chain, loc.table, err))
|
||||
continue
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanAclChains() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.cleanInputAclChain(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanPreroutingEntries(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
for _, rule := range r.entries[mangleForwardKey] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainForward, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete mangle %s guard rule %v: %w", chainForward, rule, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanInputAclChain() error {
|
||||
ok, err := r.iptablesClient.ChainExists(tableName, chainACLInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainACLInput, err)
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.entries[chainInput] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableName, chainInput, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainInput, rule, err))
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range r.entries[chainForward] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableName, chainForward, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainForward, rule, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(tableName, chainACLInput); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("clear and delete %s chain: %w", chainACLInput, err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanPreroutingEntries() error {
|
||||
ok, err := r.iptablesClient.ChainExists(tableMangle, chainPrerouting)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s in %s: %w", chainPrerouting, tableMangle, err)
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.entries[chainPrerouting] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainPrerouting, rule, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
if preRule, exists := r.rules[markManglePre]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePre)
|
||||
}
|
||||
}
|
||||
|
||||
if postRule, exists := r.rules[markManglePost]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPostrouting, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePost)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
285
client/firewall/iptables/dnat_linux.go
Normal file
285
client/firewall/iptables/dnat_linux.go
Normal file
@@ -0,0 +1,285 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
ruleID := rule.ID()
|
||||
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
toDestination := rule.TranslatedAddress.String()
|
||||
switch {
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
// no translated port, use original port
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
// need the "/originalport" suffix to avoid dnat port randomization
|
||||
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
|
||||
proto := strings.ToLower(string(rule.Protocol))
|
||||
|
||||
rules := make(map[firewall.RuleID]ruleInfo, 3)
|
||||
|
||||
// DNAT rule
|
||||
dnatRule := []string{
|
||||
"!", "-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-j", "DNAT",
|
||||
"--to-destination", toDestination,
|
||||
}
|
||||
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||
rules[ruleID+dnatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRdr,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
// SNAT rule
|
||||
snatRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleID+snatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTNAT,
|
||||
rule: snatRule,
|
||||
}
|
||||
|
||||
// Forward filtering rule, if fwd policy is DROP
|
||||
forwardRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "ACCEPT",
|
||||
}
|
||||
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleID+fwdSuffix] = ruleInfo{
|
||||
table: tableFilter,
|
||||
chain: chainRTFwdOut,
|
||||
rule: forwardRule,
|
||||
}
|
||||
|
||||
// Request forwarding once the rule is about to be installed, releasing
|
||||
// it if installation fails so the refcount tracks the real rules.
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||
log.Errorf("rollback failed: %v", rollbackErr)
|
||||
}
|
||||
r.releaseForwarding()
|
||||
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||
}
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
|
||||
var merr *multierror.Error
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||
// On rollback error, add to rules map for next cleanup
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
}
|
||||
if merr != nil {
|
||||
r.updateState()
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
||||
ruleID := rule.ID()
|
||||
|
||||
var merr *multierror.Error
|
||||
var found bool
|
||||
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||
found = true
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleID+dnatSuffix)
|
||||
}
|
||||
|
||||
if snatRule, exists := r.rules[ruleID+snatSuffix]; exists {
|
||||
found = true
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleID+snatSuffix)
|
||||
}
|
||||
|
||||
if fwdRule, exists := r.rules[ruleID+fwdSuffix]; exists {
|
||||
found = true
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFwdOut, fwdRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleID+fwdSuffix)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
// Release once, only if the rule was present and removed.
|
||||
if merr == nil && found {
|
||||
r.releaseForwarding()
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// releaseForwarding drops one IP forwarding reference, logging any error.
|
||||
func (r *family) releaseForwarding() {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("release IP forwarding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
||||
"--dport", strconv.Itoa(int(originalPort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||
}
|
||||
|
||||
info := ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRdr,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = info.rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *family) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainOutput, 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNATOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
|
||||
"--dport", strconv.Itoa(int(originalPort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
246
client/firewall/iptables/family_linux.go
Normal file
246
client/firewall/iptables/family_linux.go
Normal file
@@ -0,0 +1,246 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableName = tableFilter
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
|
||||
// chainACLInput is the peer ACL chain that holds installed
|
||||
// peer-filtering rules.
|
||||
chainACLInput = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// mangleForwardKey is the entries map key for mangle FORWARD guard
|
||||
// rules that prevent external DNAT from bypassing ACL rules.
|
||||
mangleForwardKey chainKey = "MANGLE-FORWARD"
|
||||
|
||||
chainInput = "INPUT"
|
||||
chainPostrouting = "POSTROUTING"
|
||||
chainPrerouting = "PREROUTING"
|
||||
chainForward = "FORWARD"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFwdIn = "NETBIRD-RT-FWD-IN"
|
||||
chainRTFwdOut = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPre = "NETBIRD-RT-PRE"
|
||||
chainRTRdr = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSClamp = "NETBIRD-RT-MSSCLAMP"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNATPre = "jump-nat-pre"
|
||||
jumpNATPost = "jump-nat-post"
|
||||
jumpNATOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
|
||||
dnatSuffix firewall.RuleID = "_dnat"
|
||||
snatSuffix firewall.RuleID = "_snat"
|
||||
fwdSuffix firewall.RuleID = "_fwd"
|
||||
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
chain string
|
||||
table string
|
||||
rule []string
|
||||
}
|
||||
|
||||
type routeRules map[firewall.RuleID][]string
|
||||
|
||||
// ruleSpec is a single iptables rule expressed as its argument list
|
||||
// (e.g. {"-i", "wg0", "-j", "DROP"}).
|
||||
type ruleSpec []string
|
||||
|
||||
// chainKey identifies the chain a seeded entry belongs to. It holds
|
||||
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
|
||||
// synthetic mangleForwardKey bucket for the mangle FORWARD guard rules.
|
||||
type chainKey string
|
||||
|
||||
// aclEntries maps a chain to the rules seeded into it to jump into or
|
||||
// guard the netbird ACL chains.
|
||||
type aclEntries map[chainKey][]ruleSpec
|
||||
|
||||
type entry struct {
|
||||
spec ruleSpec
|
||||
position int
|
||||
}
|
||||
|
||||
// ipsetCounter is the shared hash:net refcounter used by peer and
|
||||
// route ACLs alike. The ipset library does not support comments, so
|
||||
// the key is just the set name (string).
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
// family holds the per-address-family iptables state. One instance
|
||||
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
||||
// single family; the top-level Manager owns one for v4 and another
|
||||
// for v6.
|
||||
type family struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
v6 bool
|
||||
|
||||
// Peer ACL chain bookkeeping.
|
||||
entries aclEntries
|
||||
optionalEntries map[chainKey][]entry
|
||||
|
||||
// filters holds peer + route filter rules keyed by content hash.
|
||||
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
||||
filters map[nbid.RuleID]*Rule
|
||||
ipsetCounter *ipsetCounter
|
||||
|
||||
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
|
||||
// plumbing that isn't a filter rule).
|
||||
rules routeRules
|
||||
|
||||
// Routing / NAT.
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
|
||||
r := &family{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
entries: make(aclEntries),
|
||||
optionalEntries: make(map[chainKey][]entry),
|
||||
filters: make(map[nbid.RuleID]*Rule),
|
||||
rules: make(routeRules),
|
||||
mtu: mtu,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||
return struct{}{}, r.createIpSet(name, sources)
|
||||
},
|
||||
func(name string, _ struct{}) error {
|
||||
return r.deleteIpSet(name)
|
||||
},
|
||||
)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// init wires the family to the state manager and installs both the
|
||||
// route ACL containers and the peer ACL chain skeleton.
|
||||
func (r *family) init(stateManager *statemanager.Manager) error {
|
||||
r.stateManager = stateManager
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
r.seedInitialEntries()
|
||||
r.seedInitialOptionalEntries()
|
||||
|
||||
if err := r.cleanAclChains(); err != nil {
|
||||
return fmt.Errorf("clean acl chains: %w", err)
|
||||
}
|
||||
if err := r.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default chains: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset tears down all firewall state owned by this family. ACL
|
||||
// chain cleanup runs before route-chain cleanup because the route
|
||||
// chains are still referenced by FORWARD jumps installed during
|
||||
// seedInitialEntries; deleting them first would trip EBUSY.
|
||||
func (r *family) Reset() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.cleanAclChains(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
clear(r.rules)
|
||||
clear(r.filters)
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) updateState() {
|
||||
if r.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
currentState.ACLEntries6 = r.entries
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
currentState.ACLEntries = r.entries
|
||||
}
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
341
client/firewall/iptables/filter_linux.go
Normal file
341
client/firewall/iptables/filter_linux.go
Normal file
@@ -0,0 +1,341 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// AddFilterRule installs a packet-filtering rule. With destination
|
||||
// empty, the rule goes to the peer ACL input chain plus a paired
|
||||
// mangle PREROUTING rule for the redirect mark. With destination set
|
||||
// (prefix or named set), it goes to the route ACL forward chain.
|
||||
// Multi-source rules collapse to one iptables rule via the shared
|
||||
// hash:net ipset.
|
||||
func (r *family) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
if existing, ok := r.filters[ruleID]; ok {
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply source match: %w", err)
|
||||
}
|
||||
|
||||
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
r.dropSourceMatch(srcMatch)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.filters[ruleID] = rule
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *family) hasRule(id nbid.RuleID) bool {
|
||||
_, ok := r.filters[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
// hasDNATRule reports whether this family owns the DNAT rule set for
|
||||
// the given user id. DNAT rules live in r.rules under the well-known
|
||||
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
|
||||
// to pick the right family.
|
||||
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeleteFilterRule removes a previously installed filter rule. The
|
||||
// rule's stored chain/table identify where to delete from; source set
|
||||
// references are recovered from the spec via findSets and dropped
|
||||
// from the shared ipset counter.
|
||||
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
||||
ruleID := rule.ID()
|
||||
pr, ok := r.filters[ruleID]
|
||||
if !ok {
|
||||
log.Debugf("filter rule %s not found", ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteIfExists keeps both deletes idempotent so a retry after a
|
||||
// partial failure does not error on the half that was already removed.
|
||||
var merr *multierror.Error
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, pr.chain, pr.specs...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete rule from %s: %w", pr.chain, err))
|
||||
}
|
||||
if pr.mangleSpecs != nil {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, pr.mangleSpecs...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete mangle rule: %w", err))
|
||||
}
|
||||
}
|
||||
if merr != nil {
|
||||
// Leave the rule tracked so the caller retries the remaining half.
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
r.dropSourceMatch(pr.specs)
|
||||
delete(r.filters, ruleID)
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// findSets scans an iptables rule spec for "-m set --match-set <name>
|
||||
// <dir>" fragments and returns the named sets in occurrence order.
|
||||
// Used at delete time to drop ipsetCounter references.
|
||||
func findSets(rule []string) []string {
|
||||
var sets []string
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
sets = append(sets, rule[i+3])
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
||||
// shape the rest of the spec-builder consumes: empty for match-any, a
|
||||
// single prefix inline, or an ipset for multiple sources.
|
||||
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
||||
switch {
|
||||
case len(sources) == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1:
|
||||
return firewall.Network{Prefix: sources[0]}
|
||||
default:
|
||||
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
||||
}
|
||||
}
|
||||
|
||||
// applySourceMatch returns the iptables match fragment for the rule's
|
||||
// source. For a Set it increments the shared ipset's refcount; for a
|
||||
// Prefix it emits a direct -s match; for the wildcard it returns nil.
|
||||
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
switch {
|
||||
case network.IsSet():
|
||||
if r.ipsetCounter == nil {
|
||||
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
|
||||
}
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
|
||||
}
|
||||
return []string{"-m", "set", matchSet, name, "src"}, nil
|
||||
case network.IsPrefix():
|
||||
return []string{"-s", network.Prefix.String()}, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// dropSourceMatch undoes whatever applySourceMatch reserved. Safe to
|
||||
// call when the spec is empty or holds only inline matchers. Decrement
|
||||
// errors are logged but not returned: the filter rule has already been
|
||||
// deleted at that point and we don't want to leak the deletion.
|
||||
func (r *family) dropSourceMatch(srcMatch []string) {
|
||||
if r.ipsetCounter == nil {
|
||||
return
|
||||
}
|
||||
for _, name := range findSets(srcMatch) {
|
||||
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
||||
log.Errorf("rollback ipset decrement %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decrementSetCounter drops ipset references owned by a raw rule spec
|
||||
// stored in r.rules (NAT / legacy route entries). It returns an error
|
||||
// aggregate so the caller surfaces decrement failures.
|
||||
func (r *family) decrementSetCounter(rule []string) error {
|
||||
if r.ipsetCounter == nil {
|
||||
return nil
|
||||
}
|
||||
var merr *multierror.Error
|
||||
for _, name := range findSets(rule) {
|
||||
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// installFilterRule assembles and writes one iptables filter-chain
|
||||
// rule. With destination empty the rule lands in the peer ACL input
|
||||
// chain and a paired mangle PREROUTING rule is added for the redirect
|
||||
// mark. With destination set the rule lands in the route ACL forward
|
||||
// chain and there is no mangle pairing.
|
||||
func (r *family) installFilterRule(
|
||||
ruleID nbid.RuleID,
|
||||
srcMatch []string,
|
||||
destination firewall.Network,
|
||||
protocol firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (*Rule, error) {
|
||||
isRoute := !destination.IsZero()
|
||||
|
||||
proto := protoForFamily(protocol, r.v6)
|
||||
|
||||
specs := slices.Clone(srcMatch)
|
||||
var destExp []string
|
||||
if isRoute {
|
||||
var err error
|
||||
destExp, err = r.applyNetwork("-d", destination, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
specs = append(specs, destExp...)
|
||||
}
|
||||
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
|
||||
|
||||
var mangleSpecs []string
|
||||
if !isRoute {
|
||||
mangleSpecs = slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
)
|
||||
}
|
||||
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
|
||||
chain := chainACLInput
|
||||
if isRoute {
|
||||
chain = chainRTFwdIn
|
||||
}
|
||||
|
||||
// Peer ACL drops are inserted at position 1 so they precede the
|
||||
// chain's catch-all; route ACL drops are inserted at position 2
|
||||
// to sit immediately after the established/related accept rule.
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
pos := 1
|
||||
if isRoute {
|
||||
pos = 2
|
||||
}
|
||||
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
|
||||
} else {
|
||||
err = r.iptablesClient.Append(tableFilter, chain, specs...)
|
||||
}
|
||||
if err != nil {
|
||||
r.dropSourceMatch(destExp)
|
||||
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
|
||||
}
|
||||
|
||||
// The mangle redirect-mark rule is best effort: the filter rule itself
|
||||
// is what enforces the ACL, so a mangle failure must not undo it. Drop
|
||||
// the spec so teardown does not try to remove a rule that was not added.
|
||||
if mangleSpecs != nil {
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPre, mangleSpecs...); err != nil {
|
||||
log.Errorf("add mangle rule: %v", err)
|
||||
mangleSpecs = nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Rule{
|
||||
id: ruleID,
|
||||
specs: specs,
|
||||
mangleSpecs: mangleSpecs,
|
||||
chain: chain,
|
||||
v6: r.v6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// applyNetwork resolves a firewall.Network into the iptables match
|
||||
// fragment for the given direction flag (-s or -d). Set networks
|
||||
// increment the shared ipset refcount; prefixes emit a direct match;
|
||||
// an empty network returns no spec ("match any").
|
||||
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
direction := "src"
|
||||
if flag == "-d" {
|
||||
direction = "dst"
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, name, direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
}
|
||||
|
||||
// nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
// filterMatchSpecs returns the proto/port match fragment for a
|
||||
// filtering rule. The source match (-s or -m set) is built by the
|
||||
// caller and prepended.
|
||||
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
specs = append(specs, applyPort("--sport", sPort)...)
|
||||
specs = append(specs, applyPort("--dport", dPort)...)
|
||||
return specs
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
||||
}
|
||||
|
||||
if len(port.Values) > 1 {
|
||||
portList := make([]string, len(port.Values))
|
||||
for i, p := range port.Values {
|
||||
portList[i] = strconv.Itoa(int(p))
|
||||
}
|
||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||
}
|
||||
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
104
client/firewall/iptables/ipset_linux.go
Normal file
104
client/firewall/iptables/ipset_linux.go
Normal file
@@ -0,0 +1,104 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := r.createIPSet(setName); err != nil {
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||
// The refcounter records nothing when this callback errors,
|
||||
// so destroy the set or it leaks in the kernel. A partial
|
||||
// source set would also fail-open for deny rules, so the
|
||||
// rule must fail rather than install with a missing source.
|
||||
if derr := r.destroyIPSet(setName); derr != nil {
|
||||
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
|
||||
}
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) deleteIpSet(setName string) error {
|
||||
if err := r.destroyIPSet(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
log.Debugf("deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
name := r.ipsetName(set.HashedName())
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) ipsetName(name string) string {
|
||||
if r.v6 {
|
||||
return name + "-v6"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func (r *family) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if r.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
ip := addr.AsSlice()
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: uint8(prefix.Bits()),
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package iptables
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -18,25 +17,21 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type resetter interface {
|
||||
Reset() error
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
// Manager of iptables firewall. Per-family state (peer ACLs, route
|
||||
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
||||
// by family and provides the public firewall.Manager surface.
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
wgIface iFaceMapper
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
family4 *family
|
||||
rawSupported bool
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
ipv6Client *iptables.IPTables
|
||||
aclMgr6 *aclManager
|
||||
router6 *router
|
||||
family6 *family
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -57,14 +52,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
||||
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
return nil, fmt.Errorf("create family: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
@@ -81,21 +71,18 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("init ip6tables: %w", err)
|
||||
}
|
||||
m.ipv6Client = ip6Client
|
||||
|
||||
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
||||
family6, err := newFamily(ip6Client, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
return fmt.Errorf("create v6 family: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// Share the same IP forwarding state with the v4 family, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
family6.ipFwdState = m.family4.ipFwdState
|
||||
|
||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
m.ipv6Client = ip6Client
|
||||
m.family6 = family6
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -109,7 +96,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
MTU: m.family4.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
@@ -141,31 +128,24 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initChains initializes router and ACL chains for both address families,
|
||||
// rolling back on failure.
|
||||
// initChains initializes the per-family firewall state for both
|
||||
// address families, rolling back on failure.
|
||||
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||
type initStep struct {
|
||||
name string
|
||||
init func(*statemanager.Manager) error
|
||||
mgr resetter
|
||||
r *family
|
||||
}
|
||||
|
||||
steps := []initStep{
|
||||
{"router", m.router.init, m.router},
|
||||
{"acl manager", m.aclMgr.init, m.aclMgr},
|
||||
}
|
||||
steps := []initStep{{"v4", m.family4}}
|
||||
if m.hasIPv6() {
|
||||
steps = append(steps,
|
||||
initStep{"v6 router", m.router6.init, m.router6},
|
||||
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
||||
)
|
||||
steps = append(steps, initStep{"v6", m.family6})
|
||||
}
|
||||
|
||||
var initialized []initStep
|
||||
for _, s := range steps {
|
||||
if err := s.init(stateManager); err != nil {
|
||||
if err := s.r.init(stateManager); err != nil {
|
||||
for i := len(initialized) - 1; i >= 0; i-- {
|
||||
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
||||
if rerr := initialized[i].r.Reset(); rerr != nil {
|
||||
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
||||
}
|
||||
}
|
||||
@@ -176,84 +156,50 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if ip.To4() != nil {
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
|
||||
// docs for destination semantics. Sources are a single address family;
|
||||
// the rule is dispatched to the matching v4 / v6 backend.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
fam := m.family4
|
||||
if isIPv6Rule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
fam = m.family6
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteFilterRule removes a rule previously added via AddFilterRule.
|
||||
// The rule is looked up by id in each family's filter cache.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6IptRule(rule) {
|
||||
return m.aclMgr6.DeletePeerRule(rule)
|
||||
id := rule.ID()
|
||||
if m.family4.hasRule(id) {
|
||||
return m.family4.DeleteFilterRule(rule)
|
||||
}
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6IptRule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.v6
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule.
|
||||
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||
return m.router6.DeleteRouteRule(rule)
|
||||
if m.hasIPv6() && m.family6.hasRule(id) {
|
||||
return m.family6.DeleteFilterRule(rule)
|
||||
}
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
log.Debugf("filter rule %s not found in any family", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
@@ -272,10 +218,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
return m.family6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
if err := m.family4.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -284,7 +230,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -300,18 +246,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
return m.family6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
if err := m.family4.RemoveNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -320,11 +266,11 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -341,19 +287,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclMgr6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
||||
}
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
@@ -377,11 +317,11 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
var merr *multierror.Error
|
||||
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -402,14 +342,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -424,9 +364,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddDNATRule(rule)
|
||||
return m.family6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
return m.family4.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
@@ -434,10 +374,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
||||
return m.router6.DeleteDNATRule(rule)
|
||||
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
|
||||
return m.family6.DeleteDNATRule(rule)
|
||||
}
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
return m.family4.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
@@ -454,12 +394,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -476,9 +416,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
@@ -490,9 +430,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
@@ -504,9 +444,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
@@ -518,14 +458,14 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
chainOutput = "OUTPUT"
|
||||
tableRaw = "raw"
|
||||
)
|
||||
|
||||
@@ -600,15 +540,15 @@ func (m *Manager) initNoTrackChain() error {
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOutput, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPrerouting, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); delErr != nil {
|
||||
log.Debugf("delete output jump rule: %v", delErr)
|
||||
}
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
@@ -635,11 +575,11 @@ func (m *Manager) cleanupNoTrackChain() error {
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPrerouting, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -654,3 +594,13 @@ func (m *Manager) cleanupNoTrackChain() error {
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
// isIPv6Rule reports whether the rule belongs to the IPv6 family, from
|
||||
// the destination prefix when set, otherwise from the (single-family)
|
||||
// sources.
|
||||
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
@@ -65,46 +67,39 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule2 []fw.Rule
|
||||
var rule2 fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
rr := r.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
}
|
||||
rr := rule2.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
ok, err := ipv4Client.ChainExists("filter", chainACLInput)
|
||||
require.NoError(t, err, "failed check chain exists")
|
||||
|
||||
if ok {
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainACLInput)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -126,15 +121,13 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{22}}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
|
||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
require.NotEmpty(t, rule, "deny rule should not be empty")
|
||||
require.NotNil(t, rule, "deny rule should not be nil")
|
||||
|
||||
// Verify the rule was added by checking iptables
|
||||
for _, r := range rule {
|
||||
rr := r.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
}
|
||||
rr := rule.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
})
|
||||
|
||||
t.Run("deny rule precedence test", func(t *testing.T) {
|
||||
@@ -142,36 +135,40 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
|
||||
// Add accept rule first
|
||||
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
|
||||
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add accept rule")
|
||||
|
||||
// Add deny rule second for same IP/port - this should take precedence
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
|
||||
// Inspect the actual iptables rules to verify deny rule comes before accept rule
|
||||
rules, err := ipv4Client.List("filter", chainNameInputRules)
|
||||
rules, err := ipv4Client.List("filter", chainACLInput)
|
||||
require.NoError(t, err, "failed to list iptables rules")
|
||||
|
||||
// Debug: print all rules
|
||||
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
|
||||
t.Logf("All iptables rules in chain %s:", chainACLInput)
|
||||
for i, rule := range rules {
|
||||
t.Logf(" [%d] %s", i, rule)
|
||||
}
|
||||
|
||||
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
|
||||
// match. Match on that shape instead of the legacy
|
||||
// per-(action,port) ipset names ("deny-http"/"accept-http")
|
||||
// that this test predates.
|
||||
srcMatch := fmt.Sprintf("-s %s/32", ip)
|
||||
var denyRuleIndex, acceptRuleIndex = -1, -1
|
||||
for i, rule := range rules {
|
||||
if strings.Contains(rule, "DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
|
||||
denyRuleIndex = i
|
||||
}
|
||||
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(rule, "ACCEPT") {
|
||||
if strings.Contains(rule, "-j DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
denyRuleIndex = i
|
||||
}
|
||||
if strings.Contains(rule, "-j ACCEPT") {
|
||||
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
|
||||
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
|
||||
acceptRuleIndex = i
|
||||
}
|
||||
acceptRuleIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,7 +193,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -210,27 +206,39 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
var rule2 fw.Rule
|
||||
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
}
|
||||
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.NotNil(t, rule2)
|
||||
require.Contains(t, rule2.(*Rule).specs, "-s",
|
||||
"single-source rule should use direct -s match, not an ipset")
|
||||
require.Empty(t, findSets(rule2.(*Rule).specs),
|
||||
"single-source rule should not allocate a shared ipset")
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
t.Run("delete single-source rule", func(t *testing.T) {
|
||||
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
|
||||
})
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
t.Run("multi-source uses shared ipset", func(t *testing.T) {
|
||||
sources := []netip.Prefix{
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{8080}}
|
||||
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add multi-source rule")
|
||||
require.NotNil(t, multi, "multi-source rule must produce one iptables rule")
|
||||
sets := findSets(multi.(*Rule).specs)
|
||||
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
|
||||
|
||||
require.NoError(t, manager.DeleteFilterRule(multi))
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
@@ -281,7 +289,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -52,12 +52,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
// 11. MSS clamping rule for outbound traffic
|
||||
require.Len(t, manager.rules, 11, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPostrouting, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPostrouting)
|
||||
require.True(t, exists, "postrouting jump rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPrerouting, "-j", chainRTPre)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPrerouting)
|
||||
require.True(t, exists, "prerouting jump rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -95,7 +95,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "marking rule should be inserted")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -106,8 +106,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "marking rule should be created")
|
||||
foundRule, found := manager.rules[natRuleKey]
|
||||
@@ -121,7 +121,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
// Check inverse rule
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -132,8 +132,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "inverse marking rule should be created")
|
||||
foundRule, found := manager.rules[inverseRuleKey]
|
||||
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
@@ -170,7 +170,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -181,8 +181,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||
require.False(t, exists, "marking rule should not exist")
|
||||
|
||||
_, found := manager.rules[natRuleKey]
|
||||
@@ -190,7 +190,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
// Check inverse rule removal
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -201,8 +201,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||
require.False(t, exists, "inverse marking rule should not exist")
|
||||
|
||||
_, found = manager.rules[inverseRuleKey]
|
||||
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family manager")
|
||||
require.NoError(t, r.init(nil))
|
||||
|
||||
defer func() {
|
||||
err := r.Reset()
|
||||
require.NoError(t, err, "Failed to reset router")
|
||||
require.NoError(t, err, "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -334,62 +334,30 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddFilterRule failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
stored, ok := r.filters[ruleKey.ID()]
|
||||
require.True(t, ok, "rule not stored in filters")
|
||||
t.Logf("Internal rule: %v", stored.specs)
|
||||
|
||||
// Log the internal rule
|
||||
t.Logf("Internal rule: %v", rule)
|
||||
|
||||
// Check if the rule exists in iptables
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFwdIn, stored.specs...)
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
var source firewall.Network
|
||||
if len(tt.sources) > 1 {
|
||||
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||
} else if len(tt.sources) > 0 {
|
||||
source.Prefix = tt.sources[0]
|
||||
}
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Source: source,
|
||||
Destination: firewall.Network{Prefix: tt.destination},
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
}
|
||||
|
||||
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
assert.True(t, exists, "IPSet not created")
|
||||
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSetNameInRule(t *testing.T) {
|
||||
r := &router{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rule []string
|
||||
@@ -430,7 +398,7 @@ func TestFindSetNameInRule(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := r.findSets(tc.rule)
|
||||
result := findSets(tc.rule)
|
||||
|
||||
if len(result) != len(tc.expected) {
|
||||
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||
|
||||
269
client/firewall/iptables/routing_linux.go
Normal file
269
client/firewall/iptables/routing_linux.go
Normal file
@@ -0,0 +1,269 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", "ACCEPT"}
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the current legacy management mode
|
||||
func (r *family) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *family) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *family) RemoveAllLegacyRouteRules() error {
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
"!", "-o", "lo",
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||
return fmt.Errorf("add outbound masquerade rule: %w", err)
|
||||
}
|
||||
r.rules["static-nat-outbound"] = rule1
|
||||
|
||||
// Second rule for return traffic masquerade
|
||||
rule2 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
"-o", r.wgIface.Name(),
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||
return fmt.Errorf("add return masquerade rule: %w", err)
|
||||
}
|
||||
r.rules["static-nat-return"] = rule2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
func (r *family) addMSSClampingRules() error {
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.v6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
"-j", chainRTMSSClamp,
|
||||
}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainForward, 1, jumpRule...); err != nil {
|
||||
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
||||
}
|
||||
r.rules[jumpMSSClamp] = jumpRule
|
||||
|
||||
ruleOut := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mss),
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTMSSClamp, ruleOut...); err != nil {
|
||||
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
||||
}
|
||||
r.rules["mss-clamp-out"] = ruleOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
ruleID := firewall.RuleID("established-" + chain)
|
||||
r.rules[ruleID] = establishedRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.NatFormat)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
|
||||
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
rule := []string{"-i", r.wgIface.Name()}
|
||||
if pair.Inverse {
|
||||
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||
}
|
||||
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
)
|
||||
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -s: %w", err)
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
rule = append(rule,
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainRTPre, 1, rule...); err != nil {
|
||||
r.dropSourceMatch(rule)
|
||||
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.NatFormat)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
|
||||
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("marking rule %s not found", ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
@@ -1,18 +1,20 @@
|
||||
package iptables
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
ruleID string
|
||||
ipsetName string
|
||||
import "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
|
||||
// Rule to handle management of rules. Source set membership (when the
|
||||
// rule was built against a shared hash:net ipset) is encoded in specs;
|
||||
// DeleteFilterRule recovers it via findSets so the refcounter can drop
|
||||
// the right reference.
|
||||
type Rule struct {
|
||||
id manager.RuleID
|
||||
specs []string
|
||||
mangleSpecs []string
|
||||
ip string
|
||||
chain string
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) ID() string {
|
||||
return r.ruleID
|
||||
// ID returns the rule id
|
||||
func (r *Rule) ID() manager.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
func newIpList(ip string) *ipList {
|
||||
ips := make(map[string]struct{})
|
||||
ips[ip] = struct{}{}
|
||||
|
||||
return &ipList{
|
||||
ips: ips,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{
|
||||
IPs: s.ips,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ips = temp.IPs
|
||||
|
||||
if temp.IPs == nil {
|
||||
temp.IPs = make(map[string]struct{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsets map[string]*ipList
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsets: make(map[string]*ipList),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
||||
s.ipsets[ipsetName] = list
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipsetNames() []string {
|
||||
names := make([]string, 0, len(s.ipsets))
|
||||
for name := range s.ipsets {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{
|
||||
IPSets: s.ipsets,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ipsets = temp.IPSets
|
||||
|
||||
if temp.IPSets == nil {
|
||||
temp.IPSets = make(map[string]*ipList)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -29,17 +29,13 @@ type ShutdownState struct {
|
||||
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
|
||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||
|
||||
// IPv6 counterparts
|
||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -57,17 +53,14 @@ func (s *ShutdownState) Cleanup() error {
|
||||
}
|
||||
|
||||
if s.RouteRules != nil {
|
||||
ipt.router.rules = s.RouteRules
|
||||
ipt.family4.rules = s.RouteRules
|
||||
}
|
||||
if s.RouteIPsetCounter != nil {
|
||||
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||
}
|
||||
|
||||
if s.ACLEntries != nil {
|
||||
ipt.aclMgr.entries = s.ACLEntries
|
||||
}
|
||||
if s.ACLIPsetStore != nil {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
ipt.family4.entries = s.ACLEntries
|
||||
}
|
||||
|
||||
// Clean up v6 state even if the current run has no IPv6.
|
||||
@@ -79,16 +72,13 @@ func (s *ShutdownState) Cleanup() error {
|
||||
}
|
||||
if ipt.hasIPv6() {
|
||||
if s.RouteRules6 != nil {
|
||||
ipt.router6.rules = s.RouteRules6
|
||||
ipt.family6.rules = s.RouteRules6
|
||||
}
|
||||
if s.RouteIPsetCounter6 != nil {
|
||||
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
}
|
||||
if s.ACLEntries6 != nil {
|
||||
ipt.aclMgr6.entries = s.ACLEntries6
|
||||
}
|
||||
if s.ACLIPsetStore6 != nil {
|
||||
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
||||
ipt.family6.entries = s.ACLEntries6
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
27
client/firewall/iptables/testhelpers_linux_test.go
Normal file
27
client/firewall/iptables/testhelpers_linux_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
|
||||
}
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package manager
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
||||
@@ -16,6 +15,12 @@ import (
|
||||
// method but the IPv6 firewall components were not initialized.
|
||||
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
||||
|
||||
// ErrNoSources is returned when AddFilterRule is called with an empty
|
||||
// source list. "Match any source" must be expressed explicitly with a
|
||||
// /0 prefix; an empty list is a caller error and is rejected rather
|
||||
// than silently widening the rule to every source.
|
||||
var ErrNoSources = errors.New("rule has no sources")
|
||||
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
@@ -23,13 +28,18 @@ const (
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
)
|
||||
|
||||
// RuleID identifies a firewall rule. It is a typed string so the
|
||||
// compiler catches accidental mixing with arbitrary string keys. It is
|
||||
// only an identifier and does not implement Rule.
|
||||
type RuleID string
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
//
|
||||
// Each firewall type for different OS can use different type
|
||||
// of the properties to hold data of the created rule
|
||||
type Rule interface {
|
||||
// ID returns the rule id
|
||||
ID() string
|
||||
ID() RuleID
|
||||
}
|
||||
|
||||
// RuleDirection is the traffic direction which a rule is applied
|
||||
@@ -91,6 +101,13 @@ func (d Network) IsPrefix() bool {
|
||||
return d.Prefix.IsValid()
|
||||
}
|
||||
|
||||
// IsZero returns true if the network designates no destination, i.e. it
|
||||
// is the zero value. A zero Network is the peer-rule sentinel; a non-zero
|
||||
// one carries a prefix or set destination.
|
||||
func (d Network) IsZero() bool {
|
||||
return !d.IsPrefix() && !d.IsSet()
|
||||
}
|
||||
|
||||
// Manager is the high level abstraction of a firewall manager
|
||||
//
|
||||
// It declares methods which handle actions required by the
|
||||
@@ -101,43 +118,42 @@ type Manager interface {
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
// AddFilterRule adds a packet-filtering rule to the firewall.
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
// If destination is the zero Network, the rule applies to traffic
|
||||
// inbound to this node, i.e. peer ACL semantics, installed in
|
||||
// the kernel's input chain. If destination is set (prefix or
|
||||
// set), the rule applies to forwarded traffic with that
|
||||
// destination, route ACL semantics, installed in the forward
|
||||
// chain.
|
||||
//
|
||||
// Note: Callers should call Flush() after adding rules to ensure
|
||||
// they are applied to the kernel and rule handles are refreshed.
|
||||
AddPeerFiltering(
|
||||
// sources must be a single address family; the caller splits mixed
|
||||
// families and calls once per family. "Match any source" must be
|
||||
// expressed with an explicit /0 prefix; an empty sources list is
|
||||
// rejected with ErrNoSources so a zeroed list can never widen a
|
||||
// rule to every source.
|
||||
//
|
||||
// Note: callers should call Flush() after adding rules.
|
||||
AddFilterRule(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
sources []netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
) ([]Rule, error)
|
||||
) (Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
DeletePeerRule(rule Rule) error
|
||||
// DeleteFilterRule removes a filtering rule previously added via
|
||||
// AddFilterRule. The rule's own type identifies whether it lives
|
||||
// in the peer (input) or route (forward) path.
|
||||
DeleteFilterRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
IsStateful() bool
|
||||
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort, dPort *Port,
|
||||
action Action,
|
||||
) (Rule, error)
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
|
||||
// AddNatRule inserts a routing NAT rule
|
||||
AddNatRule(pair RouterPair) error
|
||||
|
||||
@@ -185,8 +201,9 @@ type Manager interface {
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||
// GenKey builds the rule id for this pair from the given format.
|
||||
func (p RouterPair) GenKey(format string) RuleID {
|
||||
return RuleID(fmt.Sprintf(format, p.ID, p.Inverse))
|
||||
}
|
||||
|
||||
// LegacyManager defines the interface for legacy management operations
|
||||
@@ -242,6 +259,20 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
return merged
|
||||
}
|
||||
|
||||
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
|
||||
// plain v4 form, shifting the prefix length out of the 96-bit mapped
|
||||
// range. Other prefixes are returned unchanged. Keeping prefixes
|
||||
// unmapped ensures v4 rules match consistently and the match builders
|
||||
// read the correct address length.
|
||||
func UnmapPrefix(p netip.Prefix) netip.Prefix {
|
||||
addr := p.Addr()
|
||||
if !addr.Is4In6() {
|
||||
return p
|
||||
}
|
||||
bits := max(p.Bits()-96, 0)
|
||||
return netip.PrefixFrom(addr.Unmap(), bits)
|
||||
}
|
||||
|
||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func SortPrefixes(prefixes []netip.Prefix) {
|
||||
|
||||
@@ -13,13 +13,13 @@ type ForwardRule struct {
|
||||
TranslatedPort Port
|
||||
}
|
||||
|
||||
func (r ForwardRule) ID() string {
|
||||
func (r ForwardRule) ID() RuleID {
|
||||
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||
r.Protocol,
|
||||
r.DestinationPort.String(),
|
||||
r.TranslatedAddress.String(),
|
||||
r.TranslatedPort.String())
|
||||
return id
|
||||
return RuleID(id)
|
||||
}
|
||||
|
||||
func (r ForwardRule) String() string {
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h Set) Comment() string {
|
||||
|
||||
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||
// sort for consistent naming
|
||||
prefixes = slices.Clone(prefixes)
|
||||
SortPrefixes(prefixes)
|
||||
|
||||
hash := sha256.New()
|
||||
|
||||
@@ -1,713 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
)
|
||||
|
||||
const flushError = "flush: %w"
|
||||
|
||||
type AclManager struct {
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
af addrFamily
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
}
|
||||
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for both type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create nf conn: %w", err)
|
||||
}
|
||||
|
||||
return &AclManager{
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routingFwChainName: routingFwChainName,
|
||||
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) init(workTable *nftables.Table) error {
|
||||
m.workTable = workTable
|
||||
return m.createDefaultChains()
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var ipset *nftables.Set
|
||||
if ipsetName != "" {
|
||||
var err error
|
||||
ipset, err = m.addIpToSet(ipsetName, ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRules = append(newRules, ioRule)
|
||||
return newRules, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if r.nftSet == nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.ID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||
if !ok {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.ID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
if _, ok := ips[r.ip.String()]; ok {
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
|
||||
return err
|
||||
}
|
||||
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(ips) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m.rules, r.ID())
|
||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||
|
||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.rConn.FlushSet(r.nftSet)
|
||||
m.rConn.DelSet(r.nftSet)
|
||||
m.ipsetStore.deleteIpset(r.nftSet.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||
func (m *AclManager) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainInputRules,
|
||||
Position: 0,
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (m *AclManager) Flush() error {
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addIOFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
nftRule: r.nftRule,
|
||||
mangleRule: r.mangleRule,
|
||||
nftSet: r.nftSet,
|
||||
ruleID: r.ruleID,
|
||||
ip: ip,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var expressions []expr.Any
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: m.af.protoOffset,
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
protoData, err := m.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||
}
|
||||
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: []byte{protoData},
|
||||
})
|
||||
}
|
||||
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: m.af.srcAddrOffset,
|
||||
Len: m.af.addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
if ipset == nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipset.Name,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
expressions = append(expressions, applyPort(sPort, true)...)
|
||||
expressions = append(expressions, applyPort(dPort, false)...)
|
||||
|
||||
mainExpressions := slices.Clone(expressions)
|
||||
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
case firewall.ActionDrop:
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(ruleId)
|
||||
|
||||
chain := m.chainInputRules
|
||||
rule := &nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: mainExpressions,
|
||||
UserData: userData,
|
||||
}
|
||||
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var nftRule *nftables.Rule
|
||||
if action == firewall.ActionDrop {
|
||||
nftRule = m.rConn.InsertRule(rule)
|
||||
} else {
|
||||
nftRule = m.rConn.AddRule(rule)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
|
||||
}
|
||||
|
||||
ruleStruct := &Rule{
|
||||
nftRule: nftRule,
|
||||
// best effort mangle rule
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
}
|
||||
m.rules[ruleId] = ruleStruct
|
||||
if ipset != nil {
|
||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||
}
|
||||
|
||||
return ruleStruct, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||
if m.chainPrerouting == nil {
|
||||
log.Warn("prerouting chain is not created")
|
||||
return nil
|
||||
}
|
||||
|
||||
preroutingExprs := slices.Clone(expressions)
|
||||
|
||||
// interface
|
||||
preroutingExprs = append([]expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}, preroutingExprs...)
|
||||
|
||||
// local destination and mark
|
||||
preroutingExprs = append(preroutingExprs,
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
nfRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := m.createChain(chainNameInputRules)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
m.chainInputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
// Chain is created by route manager
|
||||
// TODO: move creation to a common place
|
||||
m.chainPrerouting = &nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
}
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.routingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: m.workTable,
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
insertReturnTrafficRule(m.rConn, m.workTable, chain)
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: m.workTable,
|
||||
Hooknum: hookNum,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
return m.rConn.AddChain(chain)
|
||||
}
|
||||
|
||||
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: to,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
|
||||
m.ipsetStore.newIpset(ipset.Name)
|
||||
}
|
||||
|
||||
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||
}
|
||||
|
||||
m.ipsetStore.AddIpToSet(ipset.Name, ip)
|
||||
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// createSet in given table by name
|
||||
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
|
||||
ipset := &nftables.Set{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: m.af.setKeyType,
|
||||
}
|
||||
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
return nil, fmt.Errorf("create set: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush created set: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to flush nftables: %v", err)
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime *= 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||
if m.workTable == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := m.rConn.GetRules(m.workTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) == 0 {
|
||||
continue
|
||||
}
|
||||
split := bytes.Split(rule.UserData, []byte(" "))
|
||||
r, ok := m.rules[string(split[0])]
|
||||
if ok {
|
||||
if mangle {
|
||||
*r.mangleRule = *rule
|
||||
} else {
|
||||
*r.nftRule = *rule
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":" + string(proto) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
if dPort != nil {
|
||||
rulesetID += dPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
rulesetID += strconv.Itoa(int(action))
|
||||
if ipset == nil {
|
||||
return "ip:" + ip.String() + rulesetID
|
||||
}
|
||||
return "set:" + ipset.Name + rulesetID
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
return b
|
||||
}
|
||||
|
||||
|
||||
// ipToBytes converts net.IP to the correct byte length for the address family.
|
||||
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
||||
if af.addrLen == 4 {
|
||||
return ip.To4()
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
890
client/firewall/nftables/chains_linux.go
Normal file
890
client/firewall/nftables/chains_linux.go
Normal file
@@ -0,0 +1,890 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) createContainers() error {
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: &prio,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingRdr,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePostrouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameMangleForward,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
r.addPostroutingRules()
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *family) setupDataPlaneMark() error {
|
||||
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||
return errors.New("no mangle chains found")
|
||||
}
|
||||
|
||||
ctNew := getCtNewExprs()
|
||||
preExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
preExprs = append(preExprs, ctNew...)
|
||||
preExprs = append(preExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
preNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: preExprs,
|
||||
}
|
||||
r.conn.AddRule(preNftRule)
|
||||
|
||||
postExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
postExprs = append(postExprs, ctNew...)
|
||||
postExprs = append(postExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
postNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePostrouting],
|
||||
Exprs: postExprs,
|
||||
}
|
||||
r.conn.AddRule(postNftRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) acceptForwardRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.acceptFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) acceptFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available.
|
||||
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
|
||||
if err := r.acceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) getAcceptForwardRules() [][]string {
|
||||
intf := r.wgIface.Name()
|
||||
return [][]string{
|
||||
{"-i", intf, "-j", "ACCEPT"},
|
||||
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||
// This is used when iptables is not available.
|
||||
func (r *family) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
forwardChain := &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertForwardAcceptRules(forwardChain, intf)
|
||||
|
||||
inputChain := &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertInputAcceptRule(inputChain, intf)
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||
func (r *family) acceptExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
for _, chain := range chains {
|
||||
r.applyExternalChainAccept(chain, intf)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush external chain rules: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
|
||||
if chain.Hooknum == nil {
|
||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||
|
||||
switch *chain.Hooknum {
|
||||
case *nftables.ChainHookForward:
|
||||
r.insertForwardAcceptRules(chain, intf)
|
||||
case *nftables.ChainHookInput:
|
||||
r.insertInputAcceptRule(chain, intf)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||
if err != nil {
|
||||
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
return
|
||||
}
|
||||
r.insertForwardIifRule(chain, intf, existing)
|
||||
r.insertForwardOifEstablishedRule(chain, intf, existing)
|
||||
}
|
||||
|
||||
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||
if existing[userDataAcceptForwardRuleIif] {
|
||||
return
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||
if existing[userDataAcceptForwardRuleOif] {
|
||||
return
|
||||
}
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||
if err != nil {
|
||||
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
return
|
||||
}
|
||||
if existing[userDataAcceptInputRule] {
|
||||
return
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
})
|
||||
}
|
||||
|
||||
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
|
||||
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list rules: %w", err)
|
||||
}
|
||||
present := map[string]bool{}
|
||||
for _, rule := range rules {
|
||||
if !isNetbirdAcceptRuleTag(rule.UserData) {
|
||||
continue
|
||||
}
|
||||
present[string(rule.UserData)] = true
|
||||
}
|
||||
return present, nil
|
||||
}
|
||||
|
||||
func isNetbirdAcceptRuleTag(userData []byte) bool {
|
||||
switch string(userData) {
|
||||
case userDataAcceptForwardRuleIif,
|
||||
userDataAcceptForwardRuleOif,
|
||||
userDataAcceptInputRule:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptFilterRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
|
||||
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != table.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
if chain.Name != chainNameForward && chain.Name != chainNameInput {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||
// ensuring cleanup works even after a crash or if chains changed.
|
||||
func (r *family) removeExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, chain := range chains {
|
||||
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove rules from external chain %s/%s: %w", chain.Table.Name, chain.Name, err))
|
||||
continue
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush external chain %s/%s: %w", chain.Table.Name, chain.Name, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||
func (r *family) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
log.Debugf("list chains for family %d: %v", family, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, chain := range allChains {
|
||||
if r.isExternalChain(chain) {
|
||||
chains = append(chains, chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return chains
|
||||
}
|
||||
|
||||
func (r *family) isExternalChain(chain *nftables.Chain) bool {
|
||||
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
||||
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Type != nftables.ChainTypeFilter {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Hooknum == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||
}
|
||||
|
||||
func isIptablesTable(name string) bool {
|
||||
switch name {
|
||||
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chainInputRules,
|
||||
Position: 0,
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (r *family) Flush() error {
|
||||
if err := r.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||
if r.chainPrerouting == nil {
|
||||
log.Warn("prerouting chain is not created")
|
||||
return nil
|
||||
}
|
||||
|
||||
preroutingExprs := slices.Clone(expressions)
|
||||
|
||||
// interface
|
||||
preroutingExprs = append([]expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}, preroutingExprs...)
|
||||
|
||||
// local destination and mark
|
||||
preroutingExprs = append(preroutingExprs,
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
nfRule := r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (r *family) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := r.createChain(chainNameInputRules)
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
r.chainInputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
r.chainPrerouting = r.chains[chainNameManglePrerouting]
|
||||
|
||||
r.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: r.routingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) createChain(name string) *nftables.Chain {
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: r.workTable,
|
||||
}
|
||||
|
||||
chain = r.conn.AddChain(chain)
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, chain)
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: r.workTable,
|
||||
Hooknum: hookNum,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
return r.conn.AddChain(chain)
|
||||
}
|
||||
|
||||
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: to,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to flush nftables: %v", err)
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime *= 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||
if r.workTable == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := r.conn.GetRules(r.workTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) == 0 {
|
||||
continue
|
||||
}
|
||||
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if mangle {
|
||||
if pr.mangleRule != nil {
|
||||
*pr.mangleRule = *rule
|
||||
}
|
||||
} else {
|
||||
*pr.nftRule = *rule
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
550
client/firewall/nftables/dnat_linux.go
Normal file
550
client/firewall/nftables/dnat_linux.go
Normal file
@@ -0,0 +1,550 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/nftables/xt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
ruleID := rule.ID()
|
||||
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(rule.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
// Request forwarding once the rule is about to be installed, releasing
|
||||
// it if a later step fails so the refcount tracks the real rules.
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.addDnatRedirect(rule, protoNum, ruleID); err != nil {
|
||||
r.releaseForwarding()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.addDnatMasq(rule, protoNum, ruleID)
|
||||
|
||||
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
||||
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
||||
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
||||
// TODO: find chains with drop policies and add rules there
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
r.releaseForwarding()
|
||||
return nil, fmt.Errorf("flush rules: %w", err)
|
||||
}
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
|
||||
dnatExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
}
|
||||
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
|
||||
|
||||
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
||||
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
||||
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
||||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
||||
return r.addXTablesRedirect(dnatExprs, ruleID, rule)
|
||||
}
|
||||
}
|
||||
|
||||
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dnatExprs = append(dnatExprs, additionalExprs...)
|
||||
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: regProtoMin,
|
||||
RegProtoMax: regProtoMax,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: dnatExprs,
|
||||
UserData: []byte(ruleID + dnatSuffix),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
r.rules[ruleID+dnatSuffix] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
switch {
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
return r.handlePortRange(rule)
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
return r.handleAddressOnly(rule)
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
return r.handleSinglePort(rule)
|
||||
default:
|
||||
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
||||
},
|
||||
}
|
||||
return exprs, 2, 3, nil
|
||||
}
|
||||
|
||||
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
}
|
||||
return exprs, 0, 0, nil
|
||||
}
|
||||
|
||||
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||
},
|
||||
}
|
||||
return exprs, 2, 0, nil
|
||||
}
|
||||
|
||||
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleID firewall.RuleID, rule firewall.ForwardRule) error {
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.Counter{},
|
||||
&expr.Target{
|
||||
Name: "DNAT",
|
||||
Rev: 2,
|
||||
Info: &xt.NatRange2{
|
||||
NatRange: xt.NatRange{
|
||||
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
||||
MinIP: rule.TranslatedAddress.AsSlice(),
|
||||
MaxIP: rule.TranslatedAddress.AsSlice(),
|
||||
MinPort: rule.TranslatedPort.Values[0],
|
||||
MaxPort: rule.TranslatedPort.Values[1],
|
||||
},
|
||||
BasePort: rule.DestinationPort.Values[0],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
natTable := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: natTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: natTable,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
},
|
||||
Exprs: dnatExprs,
|
||||
UserData: []byte(ruleID + dnatSuffix),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
r.rules[ruleID+dnatSuffix] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) {
|
||||
masqExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: r.af.dstAddrOffset,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
}
|
||||
|
||||
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
|
||||
masqExprs = append(masqExprs, &expr.Masq{})
|
||||
|
||||
masqRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: masqExprs,
|
||||
UserData: []byte(ruleID + snatSuffix),
|
||||
}
|
||||
r.conn.AddRule(masqRule)
|
||||
r.rules[ruleID+snatSuffix] = masqRule
|
||||
}
|
||||
|
||||
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
||||
ruleID := rule.ID()
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var needsFlush bool
|
||||
var found bool
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||
found = true
|
||||
if dnatRule.Handle == 0 {
|
||||
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleID+dnatSuffix)
|
||||
delete(r.rules, ruleID+dnatSuffix)
|
||||
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if masqRule, exists := r.rules[ruleID+snatSuffix]; exists {
|
||||
found = true
|
||||
if masqRule.Handle == 0 {
|
||||
log.Warnf("snat rule %s has no handle, removing stale entry", ruleID+snatSuffix)
|
||||
delete(r.rules, ruleID+snatSuffix)
|
||||
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsFlush {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
}
|
||||
|
||||
if merr != nil {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
delete(r.rules, ruleID+dnatSuffix)
|
||||
delete(r.rules, ruleID+snatSuffix)
|
||||
|
||||
// Release once, only if the rule was present and removed.
|
||||
if found {
|
||||
r.releaseForwarding()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// releaseForwarding drops one IP forwarding reference, logging any error.
|
||||
func (r *family) releaseForwarding() {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("release IP forwarding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 3,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||
},
|
||||
}
|
||||
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *family) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||
},
|
||||
}
|
||||
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
249
client/firewall/nftables/family_linux.go
Normal file
249
client/firewall/nftables/family_linux.go
Normal file
@@ -0,0 +1,249 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
const (
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
tableRaw = "raw"
|
||||
tableSecurity = "security"
|
||||
|
||||
chainNameNatPrerouting = "PREROUTING"
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
// Peer ACL chain names.
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
flushError = "flush: %w"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
|
||||
dnatSuffix firewall.RuleID = "_dnat"
|
||||
snatSuffix firewall.RuleID = "_snat"
|
||||
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
refreshRulesMapError = "refresh rules map: %w"
|
||||
)
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||
)
|
||||
|
||||
type setInput struct {
|
||||
set firewall.Set
|
||||
prefixes []netip.Prefix
|
||||
}
|
||||
|
||||
// family holds the per-address-family nftables state. One instance
|
||||
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
||||
// single family; the top-level Manager owns one for v4 and another
|
||||
// for v6. The name predates the peer-ACL absorption; it's effectively
|
||||
// the per-family backend now.
|
||||
type family struct {
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
|
||||
// filters holds peer + route filter rules keyed by content hash.
|
||||
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
||||
filters map[firewall.RuleID]*Rule
|
||||
|
||||
// rules holds NAT, DNAT, and external accept rules (auxiliary
|
||||
// plumbing that isn't a filter rule).
|
||||
rules map[firewall.RuleID]*nftables.Rule
|
||||
|
||||
// Peer ACL chain handles.
|
||||
chainInputRules *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
routingFwChainName string
|
||||
|
||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||
|
||||
af addrFamily
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*family, error) {
|
||||
r := &family{
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
filters: make(map[firewall.RuleID]*Rule),
|
||||
rules: make(map[firewall.RuleID]*nftables.Rule),
|
||||
routingFwChainName: chainNameRoutingFw,
|
||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
r.deleteIpSet,
|
||||
)
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
log.Debugf("ip filter table not found: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *family) init(workTable *nftables.Table) error {
|
||||
r.workTable = workTable
|
||||
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from filter table: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
if err := r.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default acl chains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables filter table rules from the system
|
||||
func (r *family) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func hookName(hook *nftables.ChainHook) string {
|
||||
if hook == nil {
|
||||
return "unknown"
|
||||
}
|
||||
switch *hook {
|
||||
case *nftables.ChainHookForward:
|
||||
return chainNameForward
|
||||
case *nftables.ChainHookInput:
|
||||
return chainNameInput
|
||||
default:
|
||||
return fmt.Sprintf("hook(%d)", *hook)
|
||||
}
|
||||
}
|
||||
|
||||
func familyName(family nftables.TableFamily) string {
|
||||
switch family {
|
||||
case nftables.TableFamilyIPv4:
|
||||
return "ip"
|
||||
case nftables.TableFamilyIPv6:
|
||||
return "ip6"
|
||||
case nftables.TableFamilyINet:
|
||||
return "inet"
|
||||
default:
|
||||
return fmt.Sprintf("family(%d)", family)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) iptablesProto() iptables.Protocol {
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
return iptables.ProtocolIPv6
|
||||
}
|
||||
return iptables.ProtocolIPv4
|
||||
}
|
||||
|
||||
func (r *family) refreshRulesMap() error {
|
||||
var merr *multierror.Error
|
||||
newRules := make(map[firewall.RuleID]*nftables.Rule)
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||
// preserve existing entries for this chain since we can't verify their state
|
||||
for k, v := range r.rules {
|
||||
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||
newRules[k] = v
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
newRules[firewall.RuleID(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
r.rules = newRules
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
441
client/firewall/nftables/filter_linux.go
Normal file
441
client/firewall/nftables/filter_linux.go
Normal file
@@ -0,0 +1,441 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
)
|
||||
|
||||
// AddFilterRule installs one nftables packet-filter rule. With
|
||||
// destination empty the rule goes to the peer ACL input chain plus a
|
||||
// paired prerouting mangle rule for the redirect mark. With
|
||||
// destination set (prefix or named set) it goes to the route ACL
|
||||
// forward chain. Multi-source rules collapse to one nftables rule
|
||||
// backed by the shared refcounted hash:net set.
|
||||
func (r *family) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
isRoute := !destination.IsZero()
|
||||
|
||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
if existing, ok := r.filters[ruleID]; ok {
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
if isRoute {
|
||||
exprs, err = r.buildRouteFilterExprs(srcExprs, destination, proto, sPort, dPort)
|
||||
} else {
|
||||
exprs, err = r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
|
||||
}
|
||||
if err != nil {
|
||||
r.dropNetworkMatch(srcExprs)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mainExprs := slices.Clone(exprs)
|
||||
verdict := expr.VerdictAccept
|
||||
if action == firewall.ActionDrop {
|
||||
verdict = expr.VerdictDrop
|
||||
}
|
||||
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
|
||||
|
||||
chain := r.chainInputRules
|
||||
if isRoute {
|
||||
chain = r.chains[chainNameRoutingFw]
|
||||
}
|
||||
|
||||
userData := []byte(ruleID)
|
||||
nftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: mainExprs,
|
||||
UserData: userData,
|
||||
}
|
||||
if action == firewall.ActionDrop {
|
||||
nftRule = r.conn.InsertRule(nftRule)
|
||||
} else {
|
||||
nftRule = r.conn.AddRule(nftRule)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
r.dropNetworkMatch(exprs)
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
nftRule: nftRule,
|
||||
sources: sources,
|
||||
id: ruleID,
|
||||
}
|
||||
if !isRoute {
|
||||
rule.mangleRule = r.createPreroutingRule(exprs, userData)
|
||||
}
|
||||
r.filters[ruleID] = rule
|
||||
|
||||
log.Debugf("added filter rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
|
||||
sources, destination, proto, sPort, dPort, action)
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// buildPeerFilterExprs assembles the input-chain (peer ACL) match: the
|
||||
// IP-header protocol byte read via Payload, then source, then ports
|
||||
// (no counter), matching the historical peer shape so per-rule kernel
|
||||
// state is identical to pre-unification.
|
||||
func (r *family) buildPeerFilterExprs(
|
||||
srcExprs []expr.Any,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
) ([]expr.Any, error) {
|
||||
var exprs []expr.Any
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: r.af.protoOffset,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs, srcExprs...)
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
// buildRouteFilterExprs assembles the forward-chain (route ACL) match:
|
||||
// source, then destination, then optional proto/ports, then a counter.
|
||||
func (r *family) buildRouteFilterExprs(
|
||||
srcExprs []expr.Any,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
) ([]expr.Any, error) {
|
||||
exprs := append([]expr.Any{}, srcExprs...)
|
||||
|
||||
destExprs, err := r.applyNetwork(destination, nil, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
exprs = append(exprs, destExprs...)
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
if err != nil {
|
||||
r.dropNetworkMatch(destExprs)
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
||||
)
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Counter{})
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func (r *family) hasRule(id firewall.RuleID) bool {
|
||||
_, ok := r.filters[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeleteFilterRule removes a previously installed filter rule. Source
|
||||
// set references are recovered from the stored rule's expressions via
|
||||
// findSets and dropped from the shared refcounter.
|
||||
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
||||
ruleID := rule.ID()
|
||||
pr, ok := r.filters[ruleID]
|
||||
if !ok {
|
||||
log.Debugf("filter rule %s not found", ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// A freshly added rule carries no handle until it is read back from
|
||||
// the kernel, and Flush only refreshes the peer chains. Pull live
|
||||
// handles for this rule's chain before deciding it is stale so route
|
||||
// rules (which Flush never refreshes) can actually be deleted.
|
||||
if pr.nftRule.Handle == 0 {
|
||||
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
|
||||
log.Warnf("refresh handles for chain %s: %v", pr.nftRule.Chain.Name, err)
|
||||
}
|
||||
if pr.mangleRule != nil {
|
||||
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
||||
log.Warnf("refresh mangle handles: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pr.nftRule.Handle == 0 {
|
||||
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
|
||||
r.dropNetworkMatch(pr.nftRule.Exprs)
|
||||
delete(r.filters, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(pr.nftRule); err != nil {
|
||||
log.Errorf("queue rule delete: %v", err)
|
||||
}
|
||||
if pr.mangleRule != nil {
|
||||
if err := r.conn.DelRule(pr.mangleRule); err != nil {
|
||||
log.Errorf("queue mangle rule delete: %v", err)
|
||||
}
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete %s: %w", ruleID, err)
|
||||
}
|
||||
|
||||
r.dropNetworkMatch(pr.nftRule.Exprs)
|
||||
delete(r.filters, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
|
||||
if r.ipsetCounter == nil {
|
||||
return nil
|
||||
}
|
||||
sets := findSets(rule)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, setName := range sets {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// dropNetworkMatch undoes whatever the source/destination match
|
||||
// reserved. Safe to call when the spec is empty or holds only inline
|
||||
// matchers.
|
||||
func (r *family) dropNetworkMatch(exprs []expr.Any) {
|
||||
if r.ipsetCounter == nil {
|
||||
return
|
||||
}
|
||||
for _, e := range exprs {
|
||||
lookup, ok := e.(*expr.Lookup)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
|
||||
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) applyNetwork(
|
||||
network firewall.Network,
|
||||
setPrefixes []netip.Prefix,
|
||||
isSource bool,
|
||||
) ([]expr.Any, error) {
|
||||
if network.IsSet() {
|
||||
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||
if err != nil {
|
||||
side := "destination"
|
||||
if isSource {
|
||||
side = "source"
|
||||
}
|
||||
return nil, fmt.Errorf("%s set: %w", side, err)
|
||||
}
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
if network.IsPrefix() {
|
||||
return prefixMatchExprs(r.af, network.Prefix, isSource), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// prefixMatchExprs is the family-aware match sequence for a CIDR
|
||||
// prefix. /0 returns nil; a host prefix (full bit length for the
|
||||
// family) skips the bitwise step since the mask is all-ones. Shared
|
||||
// between family and aclManager so both treat single prefixes
|
||||
// identically.
|
||||
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
offset := af.dstAddrOffset
|
||||
if isSource {
|
||||
offset = af.srcAddrOffset
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: af.addrLen,
|
||||
}
|
||||
cmp := &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: prefix.Masked().Addr().AsSlice(),
|
||||
}
|
||||
|
||||
if ones == af.totalBits {
|
||||
return []expr.Any{payload, cmp}
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, af.totalBits)
|
||||
xor := make([]byte, af.addrLen)
|
||||
return []expr.Any{
|
||||
payload,
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: af.addrLen,
|
||||
Mask: mask,
|
||||
Xor: xor,
|
||||
},
|
||||
cmp,
|
||||
}
|
||||
}
|
||||
|
||||
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
|
||||
// src
|
||||
offset := uint32(2)
|
||||
if isSource {
|
||||
// dst
|
||||
offset = 0
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: offset,
|
||||
Len: 2,
|
||||
})
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
exprs = append(exprs,
|
||||
&expr.Range{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
for i, p := range port.Values {
|
||||
if i > 0 {
|
||||
exprs = append(exprs, &expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0x00, 0x00, 0xff, 0xff},
|
||||
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
})
|
||||
}
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(p),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
func getCtNewExprs() []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
||||
// shape the rest of the spec-builder consumes: empty for match-any, a
|
||||
// single prefix inline, or an ipset for multiple sources.
|
||||
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
||||
switch {
|
||||
case len(sources) == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1:
|
||||
return firewall.Network{Prefix: sources[0]}
|
||||
default:
|
||||
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
||||
}
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
return b
|
||||
}
|
||||
|
||||
// findSets scans an nftables rule's expressions for expr.Lookup and
|
||||
// returns the named sets in occurrence order. Used at delete time to
|
||||
// drop ipsetCounter references; peer and route ACLs go through it.
|
||||
func findSets(rule *nftables.Rule) []string {
|
||||
var sets []string
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
sets = append(sets, lookup.SetName)
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
210
client/firewall/nftables/ipset_linux.go
Normal file
210
client/firewall/nftables/ipset_linux.go
Normal file
@@ -0,0 +1,210 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||
set: set,
|
||||
prefixes: prefixes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return r.getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||
|
||||
nfset := &nftables.Set{
|
||||
Name: setName,
|
||||
Comment: input.set.Comment(),
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: r.af.setKeyType,
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
initialElements := elements[:min(maxElements, nElements)]
|
||||
|
||||
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
||||
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
||||
|
||||
// The set is committed now. If a later batch fails, destroy it: the
|
||||
// refcounter records nothing on a create-callback error, so it would
|
||||
// otherwise leak, and a partial source set fails-open for deny rules.
|
||||
if err := r.addRemainingElements(nfset, elements, maxElements); err != nil {
|
||||
if derr := r.deleteIpSet(setName, nfset); derr != nil {
|
||||
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
// addRemainingElements adds element batches beyond the initial one in
|
||||
// maxElements-sized chunks, flushing each. Called after the set has been
|
||||
// created with its first batch.
|
||||
func (r *family) addRemainingElements(nfset *nftables.Set, elements []nftables.SetElement, maxElements int) error {
|
||||
nElements := len(elements)
|
||||
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
||||
subEnd := min(subStart+maxElements, nElements)
|
||||
subElement := elements[subStart:subEnd]
|
||||
nSubPrefixes := len(subElement) / 2
|
||||
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
|
||||
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
||||
return fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, nfset.Name, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range prefixes {
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
|
||||
// For a /0 the last address is the broadcast and its Next() overflows
|
||||
// to an invalid Addr with an empty key, so wrap to the zero address,
|
||||
// which nftables reads as the open end of a full-range interval.
|
||||
var lastKey []byte
|
||||
if prefix.Bits() == 0 {
|
||||
lastKey = make([]byte, r.af.addrLen)
|
||||
} else {
|
||||
lastKey = calculateLastIP(prefix).Next().AsSlice()
|
||||
}
|
||||
|
||||
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
||||
elements = append(elements,
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastKey, IntervalEnd: true},
|
||||
)
|
||||
}
|
||||
return elements
|
||||
}
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
masked := prefix.Masked()
|
||||
if masked.Addr().Is4() {
|
||||
hostMask := ^uint32(0) >> masked.Bits()
|
||||
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
// IPv6: set host bits to all 1s
|
||||
b := masked.Addr().As16()
|
||||
bits := masked.Bits()
|
||||
for i := bits; i < 128; i++ {
|
||||
b[i/8] |= 1 << (7 - i%8)
|
||||
}
|
||||
return netip.AddrFrom16(b)
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
||||
b := addr.As4()
|
||||
return binary.BigEndian.Uint32(b[:])
|
||||
}
|
||||
|
||||
// Utility function to convert uint32 to a netip-compatible byte slice.
|
||||
func uint32ToBytes(ip uint32) [4]byte {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], ip)
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||
r.conn.DelSet(nfset)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
// Overlapping prefixes (e.g. duplicate resolved addresses) make the
|
||||
// interval set reject the batch, so merge them as createIpSet does.
|
||||
prefixes = firewall.MergeIPRanges(prefixes)
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
|
||||
// Add in batches sized like createIpSet so a large update does not
|
||||
// exceed the netlink message size limit.
|
||||
maxElements := maxPrefixesSet * 2
|
||||
for start := 0; start < len(elements); start += maxElements {
|
||||
end := min(start+maxElements, len(elements))
|
||||
if err := r.conn.SetAddElements(nfset, elements[start:end]); err != nil {
|
||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("updated set %s with %d prefixes", set.HashedName(), len(prefixes))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = r.af.srcAddrOffset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
36
client/firewall/nftables/ipset_linux_test.go
Normal file
36
client/firewall/nftables/ipset_linux_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConvertPrefixesToSetWildcard verifies that a /0 prefix produces a
|
||||
// usable interval. The last address of a /0 is the broadcast, whose Next()
|
||||
// overflows to an invalid Addr with an empty key; the IntervalEnd must wrap
|
||||
// to the zero address instead so nftables sees a full-range interval.
|
||||
func TestConvertPrefixesToSetWildcard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
af addrFamily
|
||||
prefix string
|
||||
}{
|
||||
{"IPv4 /0", afIPv4, "0.0.0.0/0"},
|
||||
{"IPv6 /0", afIPv6, "::/0"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &family{af: tt.af}
|
||||
elements := r.convertPrefixesToSet([]netip.Prefix{netip.MustParsePrefix(tt.prefix)})
|
||||
|
||||
require.Len(t, elements, 2, "expected start and interval-end element")
|
||||
assert.False(t, elements[0].IntervalEnd, "first element is the interval start")
|
||||
assert.True(t, elements[1].IntervalEnd, "second element is the interval end")
|
||||
assert.Len(t, elements[1].Key, int(tt.af.addrLen), "interval-end key must be a zero address, not empty")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsetReference map[string]int
|
||||
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsetReference: make(map[string]int),
|
||||
ipsets: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
||||
s.ipsetReference[ipsetName] = 0
|
||||
ipList := make(map[string]struct{})
|
||||
s.ipsets[ipsetName] = ipList
|
||||
return ipList
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
delete(s.ipsetReference, ipsetName)
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(ipList, ip.String())
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ipList[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, ok = ipList[ip.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
||||
s.ipsetReference[ipsetName]++
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
||||
r, ok := s.ipsetReference[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if r == 0 {
|
||||
return
|
||||
}
|
||||
s.ipsetReference[ipsetName]--
|
||||
}
|
||||
|
||||
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
||||
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
||||
return false
|
||||
}
|
||||
if s.ipsetReference[ipsetName] == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package nftables
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -45,18 +44,17 @@ type iFaceMapper interface {
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
// Manager of nftables firewall. Per-family state (peer ACLs, route
|
||||
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
||||
// by family and provides the public firewall.Manager surface.
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
router6 *router
|
||||
aclManager6 *AclManager
|
||||
family4 *family
|
||||
// IPv6 counterpart, nil when no v6 overlay.
|
||||
family6 *family
|
||||
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
@@ -75,14 +73,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
m.family4, err = newFamily(workTable, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
return nil, fmt.Errorf("create family: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
@@ -100,26 +93,21 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
|
||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||
|
||||
var err error
|
||||
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
||||
m.family6, err = newFamily(workTable6, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
return fmt.Errorf("create v6 family: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
m.family6.ipFwdState = m.family4.ipFwdState
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.router6 != nil
|
||||
return m.family6 != nil
|
||||
}
|
||||
|
||||
func (m *Manager) initIPv6() error {
|
||||
@@ -128,12 +116,8 @@ func (m *Manager) initIPv6() error {
|
||||
return fmt.Errorf("create v6 work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||
if err := m.family6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 family init: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -162,13 +146,13 @@ func (m *Manager) reconcileExternalChains() error {
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
if m.router != nil {
|
||||
if err := m.router.acceptExternalChainsRules(); err != nil {
|
||||
if m.family4 != nil {
|
||||
if err := m.family4.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
||||
}
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.acceptExternalChainsRules(); err != nil {
|
||||
if err := m.family6.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -187,12 +171,8 @@ func (m *Manager) initFirewall() (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := m.router.init(workTable); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager.init(workTable); err != nil {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
if err := m.family4.init(workTable); err != nil {
|
||||
return fmt.Errorf("family init: %w", err)
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
@@ -220,7 +200,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
MTU: m.family4.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -235,12 +215,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
|
||||
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
||||
func (m *Manager) rollbackInit() {
|
||||
if err := m.router.Reset(); err != nil {
|
||||
log.Warnf("rollback router: %v", err)
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
log.Warnf("rollback family: %v", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
log.Warnf("rollback v6 router: %v", err)
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
log.Warnf("rollback v6 family: %v", err)
|
||||
}
|
||||
}
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
@@ -251,118 +231,77 @@ func (m *Manager) rollbackInit() {
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFilterRule installs a packet-filtering rule.
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if ip.To4() != nil {
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// Destination semantics: zero Network → input chain (peer ACL);
|
||||
// set Network → forward chain (route ACL).
|
||||
//
|
||||
// Sources are a single address family; the rule is dispatched to the
|
||||
// matching per-family backend.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
fam := m.family4
|
||||
if isIPv6Rule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
fam = m.family6
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteFilterRule removes a filtering rule. The owning family is found
|
||||
// by id, refreshing from the kernel if the in-memory caches miss so a
|
||||
// stale cache cannot leak the rule. family.DeleteFilterRule is idempotent
|
||||
// when the id is absent.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6Rule(rule) {
|
||||
return m.aclManager6.DeletePeerRule(rule)
|
||||
}
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6Rule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
||||
}
|
||||
|
||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||
// For static routes, the destination prefix determines the family. For dynamic
|
||||
// routes (DomainSet), the sources determine the family since management
|
||||
// duplicates dynamic rules per family.
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
|
||||
// router; the cached maps are normally authoritative, so the kernel is only
|
||||
// consulted when neither map knows about the rule.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
id := rule.ID()
|
||||
r, err := m.routerForRuleID(id, (*router).hasRule)
|
||||
fam, err := m.familyForRuleID(rule.ID(), (*family).hasRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.DeleteRouteRule(rule)
|
||||
return fam.DeleteFilterRule(rule)
|
||||
}
|
||||
|
||||
// routerForRuleID picks the router holding the rule with the given id, using
|
||||
// familyForRuleID picks the family holding the rule with the given id, using
|
||||
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
|
||||
// from the kernel once and re-checks before falling back to the v4 router.
|
||||
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
|
||||
if has(m.router, id) {
|
||||
return m.router, nil
|
||||
// from the kernel once and re-checks before falling back to the v4 family.
|
||||
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool) (*family, error) {
|
||||
if has(m.family4, id) {
|
||||
return m.family4, nil
|
||||
}
|
||||
if m.hasIPv6() && has(m.router6, id) {
|
||||
return m.router6, nil
|
||||
if m.hasIPv6() && has(m.family6, id) {
|
||||
return m.family6, nil
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return m.router, nil
|
||||
return m.family4, nil
|
||||
}
|
||||
if err := m.router.refreshRulesMap(); err != nil {
|
||||
if err := m.family4.refreshRulesMap(); err != nil {
|
||||
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
||||
}
|
||||
if err := m.router6.refreshRulesMap(); err != nil {
|
||||
if err := m.family6.refreshRulesMap(); err != nil {
|
||||
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
||||
}
|
||||
if has(m.router6, id) && !has(m.router, id) {
|
||||
return m.router6, nil
|
||||
if has(m.family6, id) && !has(m.family4, id) {
|
||||
return m.family6, nil
|
||||
}
|
||||
return m.router, nil
|
||||
return m.family4, nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
@@ -381,10 +320,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
return m.family6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
if err := m.family4.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -396,7 +335,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
// so the eventual cleanup still works.
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -412,18 +351,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
return m.family6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
if err := m.family4.RemoveNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -445,11 +384,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
if err := m.family4.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
||||
if err := m.family6.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create v6 default allow rules: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -466,11 +405,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -484,13 +423,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -530,14 +469,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -551,12 +490,12 @@ func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.aclManager.Flush(); err != nil {
|
||||
if err := m.family4.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.Flush(); err != nil {
|
||||
if err := m.family6.Flush(); err != nil {
|
||||
return fmt.Errorf("flush v6 acl: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -577,9 +516,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddDNATRule(rule)
|
||||
return m.family6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
return m.family4.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
@@ -587,7 +526,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
|
||||
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -608,12 +547,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -630,9 +569,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
@@ -644,9 +583,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
@@ -658,9 +597,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
@@ -672,9 +611,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -903,3 +842,14 @@ func getEstablishedExprs(register uint32) []expr.Any {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// isIPv6Rule reports whether the rule belongs to the v6 table. For a
|
||||
// prefix destination the destination family decides; otherwise the
|
||||
// (single-family) sources do, since management duplicates rules per
|
||||
// family.
|
||||
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
@@ -70,13 +72,13 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 2, "expected 2 rules")
|
||||
@@ -147,15 +149,12 @@ func TestNftablesManager(t *testing.T) {
|
||||
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
|
||||
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
|
||||
|
||||
for _, r := range rule {
|
||||
err = manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
@@ -180,47 +179,39 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
// Add accept rule first
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add accept rule")
|
||||
|
||||
// Add deny rule second for the same traffic
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
t.Logf("Found %d rules in nftables chain", len(rules))
|
||||
|
||||
// Find the accept and deny rules and verify deny comes before accept
|
||||
// Single-source rules emit a direct payload+cmp on the source IP
|
||||
// (no set lookup). Match by source-IP + port + verdict instead of
|
||||
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
|
||||
// that this test predates.
|
||||
wantSrc := ip.AsSlice()
|
||||
var acceptRuleIndex, denyRuleIndex = -1, -1
|
||||
for i, rule := range rules {
|
||||
hasAcceptHTTPSet := false
|
||||
hasDenyHTTPSet := false
|
||||
hasPort80 := false
|
||||
var hasSrc, hasPort80 bool
|
||||
var action string
|
||||
|
||||
for _, e := range rule.Exprs {
|
||||
// Check for set lookup
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
switch lookup.SetName {
|
||||
case "accept-http":
|
||||
hasAcceptHTTPSet = true
|
||||
case "deny-http":
|
||||
hasDenyHTTPSet = true
|
||||
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
|
||||
if bytes.Equal(cmp.Data, wantSrc) {
|
||||
hasSrc = true
|
||||
}
|
||||
|
||||
}
|
||||
// Check for port 80
|
||||
if cmp, ok := e.(*expr.Cmp); ok {
|
||||
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
||||
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
||||
hasPort80 = true
|
||||
}
|
||||
}
|
||||
// Check for verdict
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch verdict.Kind {
|
||||
case expr.VerdictAccept:
|
||||
@@ -231,11 +222,15 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
|
||||
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
|
||||
if !hasSrc || !hasPort80 {
|
||||
continue
|
||||
}
|
||||
switch action {
|
||||
case "ACCEPT":
|
||||
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
|
||||
acceptRuleIndex = i
|
||||
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
|
||||
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
|
||||
case "DROP":
|
||||
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
|
||||
denyRuleIndex = i
|
||||
}
|
||||
}
|
||||
@@ -279,7 +274,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -361,10 +356,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||
@@ -437,10 +432,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("fd00::2")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "add v6 peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
@@ -550,7 +545,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||
}
|
||||
}
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
prefixes,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
@@ -565,7 +560,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
|
||||
func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
@@ -591,9 +586,9 @@ func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{},
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
// need fw manager to init both acl mgr and family for all chains to be present
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "pair should be inserted")
|
||||
|
||||
@@ -90,9 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
testRouter := &router{af: afIPv4}
|
||||
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
testRouter := &family{af: afIPv4}
|
||||
sourceExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Source.Prefix, true)
|
||||
destExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Destination.Prefix, false)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
@@ -100,14 +100,14 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
testingExpression = append(testingExpression, sourceExp...)
|
||||
testingExpression = append(testingExpression, destExp...)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
|
||||
found := 0
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNameManglePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
// Compare expressions up to the mark setting expressions
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||
found = 1
|
||||
@@ -135,19 +135,19 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
|
||||
// First add the NAT rule using the router's method
|
||||
// First add the NAT rule using the family's method
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule")
|
||||
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -200,11 +200,11 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func(r *router) {
|
||||
defer func(r *family) {
|
||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||
}(r)
|
||||
|
||||
@@ -314,16 +314,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddFilterRule failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
|
||||
require.True(t, ok, "Rule not found in filters map")
|
||||
rule := stored.nftRule
|
||||
|
||||
t.Log("Internal rule expressions:")
|
||||
for i, expr := range rule.Exprs {
|
||||
@@ -339,7 +339,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
var nftRule *nftables.Rule
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == ruleKey.ID() {
|
||||
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
|
||||
nftRule = rule
|
||||
break
|
||||
}
|
||||
@@ -367,12 +367,12 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
require.NoError(t, r.Reset(), "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -509,6 +509,42 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestNftablesUpdateSetMergesOverlapping verifies that UpdateSet merges
|
||||
// overlapping prefixes before adding them. An interval set rejects
|
||||
// overlapping elements, so without the merge a batch holding a /32 already
|
||||
// covered by a /24, or a duplicate address as DNS resolution can produce,
|
||||
// would fail.
|
||||
func TestNftablesUpdateSetMergesOverlapping(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err, "create work table")
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "reset family")
|
||||
}()
|
||||
|
||||
initial := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
||||
set := firewall.NewPrefixSet(initial)
|
||||
|
||||
created, err := r.createIpSet(set.HashedName(), setInput{prefixes: initial})
|
||||
require.NoError(t, err, "create ip set")
|
||||
require.NotNil(t, created)
|
||||
|
||||
overlapping := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
}
|
||||
require.NoError(t, r.UpdateSet(set, overlapping), "UpdateSet must merge overlapping prefixes")
|
||||
}
|
||||
|
||||
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
@@ -518,11 +554,11 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to create v6 work table")
|
||||
defer deleteWorkTableIPv6()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
require.NoError(t, r.Reset(), "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -861,13 +897,13 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Add a real rule to the kernel
|
||||
ruleKey, err := r.AddRouteFiltering(
|
||||
ruleKey, err := r.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
@@ -878,11 +914,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
||||
})
|
||||
|
||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||
staleKey := "stale-rule-that-does-not-exist"
|
||||
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
@@ -902,6 +938,55 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||
}
|
||||
|
||||
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
|
||||
// rule is actually removed from the kernel on delete. The route chain is
|
||||
// not refreshed by Flush, so the stored rule carries a zero handle;
|
||||
// DeleteFilterRule must pull live handles itself before issuing the
|
||||
// delete or the kernel rule leaks. Regression test for that path.
|
||||
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
ruleKey, err := r.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
firewall.ProtocolTCP,
|
||||
nil,
|
||||
&firewall.Port{Values: []uint16{80}},
|
||||
firewall.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
countKernelRules := func() int {
|
||||
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
||||
require.NoError(t, err)
|
||||
n := 0
|
||||
for _, rule := range list {
|
||||
if string(rule.UserData) == string(ruleKey.ID()) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
|
||||
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
||||
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
|
||||
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
|
||||
}
|
||||
|
||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
@@ -911,24 +996,28 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Inject a stale entry with Handle=0
|
||||
staleKey := "stale-route-rule"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
staleKey := id.RuleID("stale-route-rule")
|
||||
staleRule := &Rule{
|
||||
nftRule: &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
},
|
||||
id: staleKey,
|
||||
}
|
||||
r.filters[staleKey] = staleRule
|
||||
|
||||
// DeleteRouteRule should not return an error for stale handles
|
||||
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||
// DeleteFilterRule should not return an error for stale handles
|
||||
err = r.DeleteFilterRule(staleRule)
|
||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
|
||||
}
|
||||
|
||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
@@ -950,7 +1039,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
|
||||
// First add succeeds
|
||||
err = rtr.AddNatRule(pair)
|
||||
@@ -960,11 +1049,11 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
})
|
||||
|
||||
// Corrupt the handle to simulate stale state
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
natRuleKey := pair.GenKey(firewall.PreroutingFormat)
|
||||
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||
inverseKey := firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat)
|
||||
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
@@ -979,7 +1068,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
|
||||
found := 0
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found++
|
||||
}
|
||||
}
|
||||
@@ -1010,7 +1099,7 @@ func TestCalculateLastIP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||
r := &router{af: afIPv6}
|
||||
r := &family{af: afIPv6}
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::1/128"),
|
||||
|
||||
494
client/firewall/nftables/routing_linux.go
Normal file
494
client/firewall/nftables/routing_linux.go
Normal file
@@ -0,0 +1,494 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||
func (r *family) rollbackRules(pair firewall.RouterPair) {
|
||||
keys := []firewall.RuleID{
|
||||
pair.GenKey(firewall.ForwardingFormat),
|
||||
pair.GenKey(firewall.PreroutingFormat),
|
||||
firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat),
|
||||
}
|
||||
for _, key := range keys {
|
||||
rule, ok := r.rules[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||
}
|
||||
delete(r.rules, key)
|
||||
}
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
op := expr.CmpOpEq
|
||||
if pair.Inverse {
|
||||
op = expr.CmpOpNeq
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: op,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
exprs = append(exprs, getCtNewExprs()...)
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
)
|
||||
|
||||
ruleID := pair.GenKey(firewall.PreroutingFormat)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
r.rules[ruleID] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addPostroutingRules() {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
})
|
||||
|
||||
// Second masquerade rule for traffic going out through WireGuard interface
|
||||
exprs2 := []expr.Any{
|
||||
// Match on the second fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
},
|
||||
|
||||
// Match WireGuard interface
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
func (r *family) addMSSClampingRules() error {
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
if r.mtu <= overhead {
|
||||
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
|
||||
return nil
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 13,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 1,
|
||||
Mask: []byte{0x02},
|
||||
Xor: []byte{0x00},
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0x00},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Exthdr{
|
||||
DestRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGt,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Exthdr{
|
||||
SourceRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameMangleForward],
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
exprs = append(exprs,
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
|
||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleID] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.deleteLegacyRuleEntry(ruleID, rule)
|
||||
}
|
||||
|
||||
// deleteLegacyRuleEntry removes one legacy forwarding rule and drops its
|
||||
// ipset references. It also clears stale entries that never got a handle.
|
||||
func (r *family) deleteLegacyRuleEntry(ruleID firewall.RuleID, rule *nftables.Rule) error {
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleID)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s: %w", ruleID, err)
|
||||
}
|
||||
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the route manager's legacy management mode
|
||||
func (r *family) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *family) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *family) RemoveAllLegacyRouteRules() error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.deleteLegacyRuleEntry(k, rule); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeNatPreroutingRules() error {
|
||||
table := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
chain := &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
}
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from nat table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// Delete rules that have our UserData suffix
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||
}
|
||||
|
||||
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||
// counters will be off until the next successful removal or refresh cycle.
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.PreroutingFormat)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
log.Debugf("prerouting rule %s not found", ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleID)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,21 +1,26 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
// Rule wraps an installed filter rule (peer or route). Source set
|
||||
// membership is encoded in the rule's expressions; DeleteFilterRule
|
||||
// recovers the set name via findSets so the refcounter can drop the
|
||||
// right reference. mangleRule is set only for peer rules.
|
||||
type Rule struct {
|
||||
nftRule *nftables.Rule
|
||||
mangleRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
ruleID string
|
||||
ip net.IP
|
||||
// sources is the canonical source list this rule was created for.
|
||||
sources []netip.Prefix
|
||||
id manager.RuleID
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) ID() string {
|
||||
return r.ruleID
|
||||
// ID returns the rule id
|
||||
func (r *Rule) ID() manager.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
27
client/firewall/nftables/testhelpers_linux_test.go
Normal file
27
client/firewall/nftables/testhelpers_linux_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
|
||||
}
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
@@ -72,12 +71,14 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
// peerRules is the canonical list-based storage for peer ACL rules.
|
||||
// Match order is significant: drop rules come before accept rules so
|
||||
// callers should consult the slice in order.
|
||||
type peerRules []*PeerRule
|
||||
|
||||
type RouteRules []*RouteRule
|
||||
type routeRules []*RouteRule
|
||||
|
||||
func (r RouteRules) Sort() {
|
||||
func (r routeRules) Sort() {
|
||||
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
||||
// Deny rules come first
|
||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||
@@ -86,22 +87,44 @@ func (r RouteRules) Sort() {
|
||||
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.id, b.id)
|
||||
return strings.Compare(string(a.id), string(b.id))
|
||||
})
|
||||
}
|
||||
|
||||
// peerRuleSpec carries the parameters that define a peer filter rule,
|
||||
// threaded together through the build path so the builders take a single
|
||||
// argument instead of a long parameter list.
|
||||
type peerRuleSpec struct {
|
||||
mgmtID []byte
|
||||
sources []netip.Prefix
|
||||
ipLayer gopacket.LayerType
|
||||
matchAny bool
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[netip.Addr]RuleSet
|
||||
incomingDenyRules map[netip.Addr]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
// nativeFirewall is the kernel firewall (nftables/iptables) used for
|
||||
// the split case where peer ACLs run in userspace here but routing
|
||||
// stays in the kernel: when the userspace firewall is forced yet the
|
||||
// router keeps using the kernel, route NAT/ACLs and DNAT are
|
||||
// delegated to it. It is nil when no native backend is available.
|
||||
nativeFirewall firewall.Manager
|
||||
mutex sync.RWMutex
|
||||
|
||||
mutex sync.RWMutex
|
||||
incomingDenyRules peerRules
|
||||
incomingAcceptRules peerRules
|
||||
incomingDenyIndex peerRuleIndex
|
||||
incomingAcceptIndex peerRuleIndex
|
||||
peerRulesMap map[nbid.RuleID]*PeerRule
|
||||
|
||||
routeRules routeRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
@@ -183,24 +206,6 @@ func (d *decoder) decodePacket(data []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func parseCreateEnv() (bool, bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
||||
var err error
|
||||
@@ -231,7 +236,7 @@ func parseCreateEnv() (bool, bool, bool) {
|
||||
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
func Create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
@@ -254,11 +259,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
return d
|
||||
},
|
||||
},
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||
incomingDenyRules: make(map[netip.Addr]RuleSet),
|
||||
incomingRules: make(map[netip.Addr]RuleSet),
|
||||
wgIface: iface,
|
||||
nativeFirewall: nativeFirewall,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
stateful: !disableConntrack,
|
||||
@@ -266,6 +268,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
|
||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
@@ -320,7 +323,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
v4Rule, err := m.addRouteFiltering(
|
||||
v4Rule, err := m.addRouteRule(
|
||||
nil,
|
||||
sources,
|
||||
firewall.Network{Prefix: wgPrefix},
|
||||
@@ -336,7 +339,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
||||
|
||||
if v6Net.IsValid() {
|
||||
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
||||
v6Rule, err := m.addRouteFiltering(
|
||||
v6Rule, err := m.addRouteRule(
|
||||
nil,
|
||||
sources,
|
||||
firewall.Network{Prefix: v6Net},
|
||||
@@ -488,64 +491,136 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
// addPeerRule installs an input-chain rule that matches packets
|
||||
// by source only. Called from AddFilterRule when the caller doesn't
|
||||
// specify a destination. Mixed-family inputs are split: each family
|
||||
// gets its own rule with a family-correct ipLayer so packet decoding
|
||||
// matches what the matcher expects.
|
||||
func (m *Manager) addPeerRule(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
sources []netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
// TODO: fix in upper layers
|
||||
i, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||
}
|
||||
|
||||
i = i.Unmap()
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
mgmtId: id,
|
||||
ip: i,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
drop: action == firewall.ActionDrop,
|
||||
}
|
||||
if i.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
r.matchByIP = false
|
||||
}
|
||||
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
||||
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
var targetMap map[netip.Addr]RuleSet
|
||||
if r.drop {
|
||||
targetMap = m.incomingDenyRules
|
||||
} else {
|
||||
targetMap = m.incomingRules
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if sourcesMatchAny(sources) {
|
||||
spec := peerRuleSpec{
|
||||
mgmtID: id,
|
||||
sources: sources,
|
||||
ipLayer: layerTypeAll,
|
||||
matchAny: true,
|
||||
proto: proto,
|
||||
sPort: sPort,
|
||||
dPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
return m.addOnePeerRule(spec), nil
|
||||
}
|
||||
|
||||
if _, ok := targetMap[r.ip]; !ok {
|
||||
targetMap[r.ip] = make(RuleSet)
|
||||
// Sources are a single family; normalize v4-mapped prefixes to plain
|
||||
// v4 and pick the matching IP layer.
|
||||
normalized := make([]netip.Prefix, len(sources))
|
||||
ipLayer := layers.LayerTypeIPv4
|
||||
for i, p := range sources {
|
||||
normalized[i] = firewall.UnmapPrefix(p)
|
||||
if normalized[i].Addr().Is6() {
|
||||
ipLayer = layers.LayerTypeIPv6
|
||||
}
|
||||
}
|
||||
targetMap[r.ip][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
spec := peerRuleSpec{
|
||||
mgmtID: id,
|
||||
sources: normalized,
|
||||
ipLayer: ipLayer,
|
||||
matchAny: false,
|
||||
proto: proto,
|
||||
sPort: sPort,
|
||||
dPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
return m.addOnePeerRule(spec), nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// addOnePeerRule builds and registers a single-family peer rule, or
|
||||
// returns the existing rule when one with the same content key is
|
||||
// already installed. The caller must hold m.mutex. The content key is
|
||||
// the shared GenerateRuleID with an empty destination, so peer
|
||||
// rules dedup the same way route rules and the kernel backends do.
|
||||
//
|
||||
// There is no refcount: a content key is installed once and deleted on
|
||||
// the first DeleteFilterRule for that key. The caller must therefore
|
||||
// key its own tracking by the returned rule id so add and delete stay
|
||||
// balanced per content key; the acl manager does this via
|
||||
// peerRulesPairs. The content key is order-independent, so callers
|
||||
// passing the same sources in any order dedup to one rule.
|
||||
func (m *Manager) addOnePeerRule(spec peerRuleSpec) *PeerRule {
|
||||
ruleID := nbid.GenerateRuleID(spec.sources, firewall.Network{}, spec.proto, spec.sPort, spec.dPort, spec.action)
|
||||
if existing, ok := m.peerRulesMap[ruleID]; ok {
|
||||
return existing
|
||||
}
|
||||
|
||||
rule := m.buildPeerRule(ruleID, spec)
|
||||
m.registerPeerRule(rule)
|
||||
return rule
|
||||
}
|
||||
|
||||
func (m *Manager) buildPeerRule(ruleID nbid.RuleID, spec peerRuleSpec) *PeerRule {
|
||||
r := &PeerRule{
|
||||
id: ruleID,
|
||||
mgmtId: spec.mgmtID,
|
||||
sources: spec.sources,
|
||||
matchAny: spec.matchAny,
|
||||
action: spec.action,
|
||||
srcPort: spec.sPort,
|
||||
dstPort: spec.dPort,
|
||||
}
|
||||
if !spec.matchAny {
|
||||
r.sourceAddrs = make(map[netip.Addr]struct{}, len(spec.sources))
|
||||
for _, p := range spec.sources {
|
||||
if p.Bits() == p.Addr().BitLen() {
|
||||
r.sourceAddrs[p.Addr()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
r.protoLayer = protoToLayer(spec.proto, spec.ipLayer)
|
||||
return r
|
||||
}
|
||||
|
||||
// registerPeerRule records a freshly built peer rule in the matching
|
||||
// slice, index, and dedup map. The caller must hold m.mutex.
|
||||
func (m *Manager) registerPeerRule(r *PeerRule) {
|
||||
if r.action == firewall.ActionDrop {
|
||||
m.incomingDenyRules = append(m.incomingDenyRules, r)
|
||||
m.incomingDenyIndex.add(r)
|
||||
} else {
|
||||
m.incomingAcceptRules = append(m.incomingAcceptRules, r)
|
||||
m.incomingAcceptIndex.add(r)
|
||||
}
|
||||
m.peerRulesMap[r.id] = r
|
||||
}
|
||||
|
||||
// sourcesMatchAny reports whether the source list matches every source,
|
||||
// i.e. contains an explicit /0 prefix. An empty list does not qualify:
|
||||
// AddFilterRule rejects it with ErrNoSources, so "match any" is always
|
||||
// the deliberate /0 case.
|
||||
func sourcesMatchAny(sources []netip.Prefix) bool {
|
||||
for _, p := range sources {
|
||||
if p.Bits() == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AddFilterRule is the unified entry point for both peer (input chain)
|
||||
// and route (forward chain) filtering rules. The destination
|
||||
// distinguishes the two semantics: a zero Network installs an
|
||||
// input-side rule that matches by source only; a set Network installs
|
||||
// a forward-side rule that also matches the destination.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
@@ -553,13 +628,37 @@ func (m *Manager) AddRouteFiltering(
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
|
||||
if destination.IsZero() {
|
||||
return m.addPeerRule(id, sources, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return m.addRouteRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
|
||||
// is used to route to the correct internal path.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
if r, ok := rule.(*PeerRule); ok {
|
||||
return m.deletePeerRuleLocked(r)
|
||||
}
|
||||
|
||||
// Either our *RouteRule or, under native delegation, a native
|
||||
// route-rule object that implements firewall.Rule but isn't one of
|
||||
// our concrete types. The route path forwards the latter to the
|
||||
// native firewall.
|
||||
return m.deleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) addRouteFiltering(
|
||||
func (m *Manager) addRouteRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
@@ -568,18 +667,17 @@ func (m *Manager) addRouteFiltering(
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
return m.nativeFirewall.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||
if existingRule, ok := m.routeRulesMap[ruleID]; ok {
|
||||
return existingRule, nil
|
||||
}
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: string(ruleKey),
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
@@ -594,72 +692,57 @@ func (m *Manager) addRouteFiltering(
|
||||
|
||||
m.routeRules = append(m.routeRules, &rule)
|
||||
m.routeRules.Sort()
|
||||
m.routeRulesMap[ruleKey] = &rule
|
||||
m.routeRulesMap[ruleID] = &rule
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
return m.nativeFirewall.DeleteFilterRule(rule)
|
||||
}
|
||||
|
||||
ruleKey := nbid.RuleID(rule.ID())
|
||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||
ruleID := rule.ID()
|
||||
trimmed, _, ok := removeRuleByID(m.routeRules, ruleID)
|
||||
if !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
}
|
||||
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == string(ruleKey)
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
delete(m.routeRulesMap, ruleKey)
|
||||
m.routeRules = trimmed
|
||||
delete(m.routeRulesMap, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
// deletePeerRuleLocked removes a peer rule from the matching slice,
|
||||
// index, and dedup map. The caller must hold m.mutex.
|
||||
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
|
||||
target, index := &m.incomingAcceptRules, &m.incomingAcceptIndex
|
||||
if r.action == firewall.ActionDrop {
|
||||
target, index = &m.incomingDenyRules, &m.incomingDenyIndex
|
||||
}
|
||||
|
||||
r, ok := rule.(*PeerRule)
|
||||
trimmed, stored, ok := removeRuleByID(*target, r.id)
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
var sourceMap map[netip.Addr]RuleSet
|
||||
if r.drop {
|
||||
sourceMap = m.incomingDenyRules
|
||||
} else {
|
||||
sourceMap = m.incomingRules
|
||||
}
|
||||
|
||||
if ruleset, ok := sourceMap[r.ip]; ok {
|
||||
if _, exists := ruleset[r.id]; !exists {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(ruleset, r.id)
|
||||
if len(ruleset) == 0 {
|
||||
delete(sourceMap, r.ip)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
|
||||
*target = trimmed
|
||||
index.remove(stored)
|
||||
delete(m.peerRulesMap, r.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRuleByID removes the first rule whose id matches ruleID from
|
||||
// rules, preserving order. It returns the trimmed slice, the removed
|
||||
// rule, and whether a match was found.
|
||||
func removeRuleByID[S ~[]T, T firewall.Rule](rules S, ruleID firewall.RuleID) (S, T, bool) {
|
||||
idx := slices.IndexFunc(rules, func(r T) bool { return r.ID() == ruleID })
|
||||
var removed T
|
||||
if idx < 0 {
|
||||
return rules, removed, false
|
||||
}
|
||||
removed = rules[idx]
|
||||
return slices.Delete(rules, idx, idx+1), removed, true
|
||||
}
|
||||
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if m.nativeFirewall == nil {
|
||||
@@ -674,9 +757,11 @@ func (m *Manager) Flush() error { return nil }
|
||||
// resetState clears all firewall rules and closes connection trackers.
|
||||
// Must be called with m.mutex held.
|
||||
func (m *Manager) resetState() {
|
||||
clear(m.outgoingRules)
|
||||
clear(m.incomingDenyRules)
|
||||
clear(m.incomingRules)
|
||||
m.incomingDenyRules = m.incomingDenyRules[:0]
|
||||
m.incomingAcceptRules = m.incomingAcceptRules[:0]
|
||||
m.incomingDenyIndex.reset()
|
||||
m.incomingAcceptIndex.reset()
|
||||
clear(m.peerRulesMap)
|
||||
clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
@@ -820,11 +905,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||
case layers.LayerTypeIPv4:
|
||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
return src, dst
|
||||
return src.Unmap(), dst.Unmap()
|
||||
case layers.LayerTypeIPv6:
|
||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||
return src, dst
|
||||
return src.Unmap(), dst.Unmap()
|
||||
default:
|
||||
return netip.Addr{}, netip.Addr{}
|
||||
}
|
||||
@@ -1404,20 +1489,12 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
|
||||
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
|
||||
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
return nil, true
|
||||
}
|
||||
|
||||
@@ -1438,39 +1515,6 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
|
||||
@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@@ -114,15 +114,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Add explicit rules matching return traffic pattern
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(
|
||||
_, err := m.AddFilterRule(
|
||||
nil,
|
||||
ip,
|
||||
pfx(ip), fw.Network{},
|
||||
fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@@ -133,15 +131,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(
|
||||
_, err := m.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
)
|
||||
fw.ActionDrop)
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@@ -170,7 +166,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -210,7 +206,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -253,7 +249,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -411,7 +407,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -538,7 +534,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -546,7 +542,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -621,7 +617,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -629,7 +625,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -732,14 +728,14 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -812,13 +808,13 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -931,7 +927,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
|
||||
for _, r := range rules {
|
||||
dst := fw.Network{Prefix: r.dest}
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -1016,7 +1012,7 @@ func BenchmarkMSSClamping(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -1081,7 +1077,7 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -1136,7 +1132,7 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
@@ -496,40 +496,32 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
// add general accept rule for the same IP to test drop rule precedence
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
rules, err := manager.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
require.NotNil(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
rules, err := manager.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
||||
tc.ruleProto,
|
||||
tc.ruleSrcPort,
|
||||
tc.ruleDstPort,
|
||||
tc.ruleAction,
|
||||
"",
|
||||
)
|
||||
tc.ruleAction)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
require.NotNil(t, rules)
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
@@ -557,7 +549,7 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
@@ -672,22 +664,18 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
||||
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
require.NotNil(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
@@ -800,7 +788,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NotNil(tb, manager)
|
||||
@@ -1405,7 +1393,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.rule.action == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
@@ -1415,13 +1403,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
require.NotEmpty(t, rule)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
})
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
tc.rule.sources,
|
||||
tc.rule.dest,
|
||||
@@ -1431,10 +1419,10 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
tc.rule.action,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
require.NotEmpty(t, rule)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
})
|
||||
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
@@ -1602,9 +1590,9 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var rules []fw.Rule
|
||||
var addedRules []fw.Rule
|
||||
for _, r := range tc.rules {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
r.sources,
|
||||
r.dest,
|
||||
@@ -1615,12 +1603,12 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
rules = append(rules, rule)
|
||||
addedRules = append(addedRules, rule)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
for _, rule := range addedRules {
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1646,7 +1634,7 @@ func TestRouteACLSet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -1655,7 +1643,7 @@ func TestRouteACLSet(t *testing.T) {
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
// Add rule that uses the set (initially empty)
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -1689,7 +1677,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||
|
||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||
_, err := manager.AddRouteFiltering(
|
||||
_, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: v6Dst},
|
||||
@@ -1700,7 +1688,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add rule first time
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
// Add the same rule again
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
|
||||
// Add first rule
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add different rule (different destination)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-2"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
require.True(t, pass, "Traffic should pass with rule in place")
|
||||
|
||||
// Re-add same rule (simulates network map update)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -147,7 +147,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||
// would remove the only matching rule and cause a traffic gap.
|
||||
if rule1.ID() != rule2.ID() {
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
err = manager.DeleteFilterRule(rule1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -156,6 +156,59 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
"Traffic should still pass after rule update - no gap should occur")
|
||||
}
|
||||
|
||||
// TestBlockInvalidRoutedDualStack verifies that when the interface has an
|
||||
// IPv6 overlay address, blockInvalidRouted installs a block rule for both
|
||||
// the v4 and v6 WG prefixes and that routed traffic to the v6 prefix is
|
||||
// denied. The v4-only soft-skip path is covered by
|
||||
// TestBlockInvalidRoutedIdempotent, whose mock leaves IPv6Net invalid.
|
||||
func TestBlockInvalidRoutedDualStack(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
wgNet6 := netip.MustParsePrefix("fd00:1234::1/64")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
IPv6: wgNet6.Addr(),
|
||||
IPv6Net: wgNet6,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
rules, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rules, 2, "dual-stack interface must produce a v4 and a v6 block rule")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
assert.Equal(t, 2, ruleCount, "should have one block rule per family")
|
||||
|
||||
// v6 routed traffic to the WG prefix must be denied.
|
||||
srcIP := netip.MustParseAddr("2001:db8::1")
|
||||
dstIP := netip.MustParseAddr("fd00:1234::50")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||
assert.False(t, pass, "block rule should deny routed traffic to the v6 WG prefix")
|
||||
}
|
||||
|
||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||
// returns the same rule without duplicating.
|
||||
@@ -182,7 +235,7 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -245,7 +298,7 @@ func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -274,7 +327,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||
|
||||
// Simulate 5 network map updates with the same route rule
|
||||
for i := 0; i < 5; i++ {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -304,7 +357,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add same rule twice
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -315,7 +368,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -329,7 +382,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||
|
||||
// Delete using first reference
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
err = manager.DeleteFilterRule(rule1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify traffic no longer passes
|
||||
@@ -364,7 +417,7 @@ func setupTestManager(t *testing.T) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -89,7 +89,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
func TestManagerAddFilterRule(t *testing.T) {
|
||||
isSetFilterCalled := false
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error {
|
||||
@@ -98,7 +98,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -109,7 +109,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -131,7 +131,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -142,63 +142,33 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check rules exist in appropriate maps
|
||||
for _, r := range rule2 {
|
||||
peerRule, ok := r.(*PeerRule)
|
||||
if !ok {
|
||||
t.Errorf("rule should be a PeerRule")
|
||||
continue
|
||||
}
|
||||
// Check if rule exists in deny or allow maps based on action
|
||||
var found bool
|
||||
if peerRule.drop {
|
||||
_, found = m.incomingDenyRules[ip][r.ID()]
|
||||
} else {
|
||||
_, found = m.incomingRules[ip][r.ID()]
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("rule2 is not in the expected rules map")
|
||||
peerRule, ok := rule2.(*PeerRule)
|
||||
require.True(t, ok, "rule should be a PeerRule")
|
||||
|
||||
inMap := func() bool {
|
||||
if peerRule.action == fw.ActionDrop {
|
||||
return findRuleByID(m.incomingDenyRules, ip, rule2.ID())
|
||||
}
|
||||
return findRuleByID(m.incomingAcceptRules, ip, rule2.ID())
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
err = m.DeletePeerRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
require.True(t, inMap(), "rule2 should be in the expected rules list")
|
||||
|
||||
// Check rules are removed from appropriate maps
|
||||
for _, r := range rule2 {
|
||||
peerRule, ok := r.(*PeerRule)
|
||||
if !ok {
|
||||
t.Errorf("rule should be a PeerRule")
|
||||
continue
|
||||
}
|
||||
// Check if rule is removed from deny or allow maps based on action
|
||||
var found bool
|
||||
if peerRule.drop {
|
||||
_, found = m.incomingDenyRules[ip][r.ID()]
|
||||
} else {
|
||||
_, found = m.incomingRules[ip][r.ID()]
|
||||
}
|
||||
if found {
|
||||
t.Errorf("rule2 should be removed from the rules map")
|
||||
}
|
||||
}
|
||||
require.NoError(t, m.DeleteFilterRule(rule2), "failed to delete rule")
|
||||
|
||||
require.False(t, inMap(), "rule2 should be removed from the rules list")
|
||||
}
|
||||
|
||||
func TestSetUDPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
}, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
@@ -222,7 +192,7 @@ func TestSetUDPPacketHook(t *testing.T) {
|
||||
func TestSetTCPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
}, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
@@ -250,7 +220,7 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
@@ -260,36 +230,34 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Add multiple deny rules for different ports
|
||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||
|
||||
// Delete the first deny rule
|
||||
err = m.DeletePeerRule(rule1[0])
|
||||
err = m.DeleteFilterRule(rule1)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount = len(m.incomingDenyRules[addr])
|
||||
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||
|
||||
// Delete the second deny rule
|
||||
err = m.DeletePeerRule(rule2[0])
|
||||
err = m.DeleteFilterRule(rule2)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, exists := m.incomingDenyRules[addr]
|
||||
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
||||
m.mutex.RUnlock()
|
||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||
require.False(t, exists, "Deny rules should be cleaned up when empty")
|
||||
}
|
||||
|
||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||
@@ -299,7 +267,7 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
@@ -311,27 +279,21 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
// Simulate 10 network map updates: add rule, delete old, add new
|
||||
for i := 0; i < 10; i++ {
|
||||
// Add a deny rule
|
||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an allow rule
|
||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete them (simulating ACL manager cleanup)
|
||||
for _, r := range rules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
for _, r := range allowRules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
require.NoError(t, m.DeleteFilterRule(rules))
|
||||
require.NoError(t, m.DeleteFilterRule(allowRules))
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||
@@ -345,7 +307,7 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
@@ -354,41 +316,39 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
|
||||
// Add allow rule for port 80
|
||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add deny rule for port 22
|
||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
m.mutex.RLock()
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||
|
||||
// Delete allow rule should not affect deny rule
|
||||
err = m.DeletePeerRule(allowRule[0])
|
||||
err = m.DeleteFilterRule(allowRule)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||
|
||||
// Delete deny rule
|
||||
err = m.DeletePeerRule(denyRule[0])
|
||||
err = m.DeleteFilterRule(denyRule)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, denyExists := m.incomingDenyRules[addr]
|
||||
_, allowExists := m.incomingRules[addr]
|
||||
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
||||
allowExists := countRulesForAddr(m.incomingAcceptRules, addr) > 0
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.False(t, denyExists, "Deny rules should be empty")
|
||||
@@ -400,7 +360,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -411,7 +371,7 @@ func TestManagerReset(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -423,7 +383,7 @@ func TestManagerReset(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
|
||||
if len(m.incomingAcceptRules) != 0 || len(m.incomingDenyRules) != 0 {
|
||||
t.Errorf("rules are not empty")
|
||||
}
|
||||
}
|
||||
@@ -439,7 +399,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -449,7 +409,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
proto := fw.ProtocolUDP
|
||||
action := fw.ActionAccept
|
||||
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -502,7 +462,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(iface, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
@@ -521,7 +481,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
}, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close()
|
||||
@@ -606,7 +566,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -621,7 +581,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@@ -633,7 +593,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
}, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
@@ -845,7 +805,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -858,7 +818,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -931,7 +891,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -939,7 +899,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -1051,7 +1011,7 @@ func TestMSSClamping(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, 1280)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, 1280)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -1243,7 +1203,7 @@ func TestShouldForward(t *testing.T) {
|
||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -1358,7 +1318,7 @@ func TestShouldForward(t *testing.T) {
|
||||
|
||||
// Re-create manager to pick up the new address with IPv6
|
||||
require.NoError(t, manager.Close(nil))
|
||||
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err = Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
v6Cases := []struct {
|
||||
|
||||
@@ -66,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -126,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -198,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -310,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
|
||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -483,7 +483,7 @@ func BenchmarkPortDNAT(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
func TestPortDNATBasic(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -51,7 +51,7 @@ func TestPortDNATBasic(t *testing.T) {
|
||||
func TestPortDNATMultipleRules(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -106,7 +106,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
func TestDNATMappingManagement(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -154,7 +154,7 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
func TestInboundPortDNAT(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -204,7 +204,7 @@ func TestInboundPortDNAT(t *testing.T) {
|
||||
func TestInboundPortDNATNegative(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
327
client/firewall/uspfilter/peer_acl_bench_test.go
Normal file
327
client/firewall/uspfilter/peer_acl_bench_test.go
Normal file
@@ -0,0 +1,327 @@
|
||||
//go:build uspbench
|
||||
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
|
||||
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
|
||||
// rules, each with K source peers in its set.
|
||||
//
|
||||
// With the reverse-source index, miss cost is independent of M and
|
||||
// hit cost grows only with the number of rules touching a single
|
||||
// srcIP, not with total rule count.
|
||||
func BenchmarkPeerACLMatch(b *testing.B) {
|
||||
shapes := []struct{ M, K int }{
|
||||
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
|
||||
}
|
||||
families := []struct {
|
||||
name string
|
||||
v6 bool
|
||||
}{{"v4", false}, {"v6", true}}
|
||||
|
||||
for _, fam := range families {
|
||||
for _, s := range shapes {
|
||||
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
|
||||
runPeerACLBench(b, s.M, s.K, true, fam.v6)
|
||||
})
|
||||
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
|
||||
runPeerACLBench(b, s.M, s.K, false, fam.v6)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
|
||||
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
|
||||
|
||||
// Miss packets are dropped, so they always traverse the full peer
|
||||
// ACL matcher (every bucket) without short-circuiting and without
|
||||
// touching conntrack. Disable conntrack for the miss case so it
|
||||
// measures the matcher, not established-state lookups. The hit case
|
||||
// keeps conntrack on: an accepted packet reaches trackInbound, which
|
||||
// needs the trackers conntrack creates.
|
||||
if !hit {
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
}
|
||||
|
||||
bits := 32
|
||||
genPkt := generatePacket
|
||||
addrs := uniqueAddrs
|
||||
if v6 {
|
||||
bits = 128
|
||||
genPkt = generatePacket6
|
||||
addrs = uniqueAddrs6
|
||||
}
|
||||
|
||||
// dstIP must be a local IP so filterInbound takes the local-traffic
|
||||
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
|
||||
// address the manager doesn't own would be treated as routed and
|
||||
// short-circuit before the peer matcher.
|
||||
dstIP := addrs(1, 2)[0]
|
||||
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
|
||||
if v6 {
|
||||
// The local-IP manager needs a valid v4 address too; expose the v6
|
||||
// dst as the interface's IPv6 so IsLocalIP recognizes it.
|
||||
mockAddr = wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: dstIP,
|
||||
IPv6Net: netip.PrefixFrom(dstIP, bits),
|
||||
}
|
||||
}
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address { return mockAddr },
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
|
||||
|
||||
// Generate M policies × K source peers, all distinct.
|
||||
all := addrs(m*k, 1)
|
||||
for i := 0; i < m; i++ {
|
||||
sources := make([]netip.Prefix, k)
|
||||
for j, a := range all[i*k : (i+1)*k] {
|
||||
sources[j] = netip.PrefixFrom(a, bits)
|
||||
}
|
||||
_, err := manager.AddFilterRule(
|
||||
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Hit: cycle through real sources, picking the matching policy's port.
|
||||
// Miss: a source from a disjoint range, port 80 (matches no policy).
|
||||
var pktFn func(i int) []byte
|
||||
if hit {
|
||||
pktFn = func(i int) []byte {
|
||||
policy := i % m
|
||||
src := all[policy*k+(i%k)]
|
||||
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
|
||||
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
|
||||
}
|
||||
} else {
|
||||
miss := addrs(4096, 99)
|
||||
pktFn = func(i int) []byte {
|
||||
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
|
||||
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-build a pool to avoid allocations dominating the measurement.
|
||||
pool := make([][]byte, 1024)
|
||||
for i := range pool {
|
||||
pool[i] = pktFn(i)
|
||||
}
|
||||
|
||||
// Confirm the matcher is actually exercised: a hit packet must be
|
||||
// allowed and a miss packet dropped. Without this the benchmark
|
||||
// could silently time the routed early-return instead.
|
||||
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
|
||||
"benchmark must reach the peer ACL matcher")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(pool[i%len(pool)], 0)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
|
||||
// the source-keyed index across realistic deployment shapes. Two
|
||||
// dimensions matter: (M, K), the number of policies × peers-per-policy,
|
||||
// and overlap, the fraction of peers shared between policies.
|
||||
//
|
||||
// The output uses ReportMetric("bytes/rule") so the cost can be
|
||||
// compared across shapes directly. Total bytes = bytes/rule * M.
|
||||
func BenchmarkPeerACLIndexMemory(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
M, K int
|
||||
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
|
||||
}{
|
||||
{"M=10/K=100/disjoint", 10, 100, 0},
|
||||
{"M=100/K=100/disjoint", 100, 100, 0},
|
||||
{"M=100/K=1000/disjoint", 100, 1000, 0},
|
||||
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
|
||||
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
|
||||
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
b.Run(c.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mgr, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
|
||||
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
|
||||
|
||||
runtime.GC()
|
||||
var ms runtime.MemStats
|
||||
runtime.ReadMemStats(&ms)
|
||||
before := ms.HeapAlloc
|
||||
|
||||
// Drop the manager's external roots so we can isolate
|
||||
// the index cost. We hold the manager itself live; the
|
||||
// index is what we measure on the second pass.
|
||||
mgr.incomingAcceptIndex.reset()
|
||||
mgr.incomingDenyIndex.reset()
|
||||
mgr.incomingAcceptRules = mgr.incomingAcceptRules[:0]
|
||||
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&ms)
|
||||
after := ms.HeapAlloc
|
||||
|
||||
delta := int64(before) - int64(after)
|
||||
if delta < 0 {
|
||||
delta = 0
|
||||
}
|
||||
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
|
||||
b.ReportMetric(float64(delta), "bytes/total")
|
||||
|
||||
require.NoError(b, mgr.Close(nil))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
|
||||
b.Helper()
|
||||
pool := uniqueAddrs(k+m*k, 1) // big enough universe
|
||||
sharedLen := int(float64(k) * overlapFrac)
|
||||
if sharedLen > k {
|
||||
sharedLen = k
|
||||
}
|
||||
shared := pool[:sharedLen]
|
||||
uniquePool := pool[sharedLen:]
|
||||
|
||||
for i := 0; i < m; i++ {
|
||||
sources := make([]netip.Prefix, 0, k)
|
||||
for _, a := range shared {
|
||||
sources = append(sources, netip.PrefixFrom(a, 32))
|
||||
}
|
||||
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
|
||||
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
|
||||
for _, a := range unique {
|
||||
sources = append(sources, netip.PrefixFrom(a, 32))
|
||||
}
|
||||
_, err := mgr.AddFilterRule(
|
||||
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
}
|
||||
|
||||
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
|
||||
// policy sources / dst; seed 99 puts misses in 10/8.
|
||||
func uniqueAddrs(n int, seed int64) []netip.Addr {
|
||||
out := make([]netip.Addr, 0, n)
|
||||
seen := make(map[netip.Addr]struct{}, n)
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
miss := seed == 99
|
||||
for len(out) < n {
|
||||
var b [4]byte
|
||||
if miss {
|
||||
b[0] = 10
|
||||
b[1] = byte(r.Intn(256))
|
||||
} else {
|
||||
b[0] = 100
|
||||
b[1] = byte(64 + r.Intn(63))
|
||||
}
|
||||
b[2] = byte(r.Intn(256))
|
||||
b[3] = byte(1 + r.Intn(254))
|
||||
a := netip.AddrFrom4(b)
|
||||
if _, ok := seen[a]; ok {
|
||||
continue
|
||||
}
|
||||
seen[a] = struct{}{}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
|
||||
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
|
||||
// disjoint from any source.
|
||||
func uniqueAddrs6(n int, seed int64) []netip.Addr {
|
||||
out := make([]netip.Addr, 0, n)
|
||||
seen := make(map[netip.Addr]struct{}, n)
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
miss := seed == 99
|
||||
for len(out) < n {
|
||||
var b [16]byte
|
||||
if miss {
|
||||
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
|
||||
} else {
|
||||
b[0] = 0xfd
|
||||
}
|
||||
for x := 8; x < 16; x++ {
|
||||
b[x] = byte(r.Intn(256))
|
||||
}
|
||||
a := netip.AddrFrom16(b)
|
||||
if _, ok := seen[a]; ok {
|
||||
continue
|
||||
}
|
||||
seen[a] = struct{}{}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
|
||||
// generatePacket for the v4 case.
|
||||
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
|
||||
b.Helper()
|
||||
|
||||
ipv6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: protocol,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
var transportLayer gopacket.SerializableLayer
|
||||
switch protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
}
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
|
||||
transportLayer = tcp
|
||||
case layers.IPProtocolUDP:
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
|
||||
transportLayer = udp
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
149
client/firewall/uspfilter/peer_acl_dedup_test.go
Normal file
149
client/firewall/uspfilter/peer_acl_dedup_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
func newTestManager(t *testing.T) *Manager {
|
||||
t.Helper()
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
return m
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
|
||||
// the same peer rule twice does not create two backing rules. The acl
|
||||
// manager keys its own cache, but the firewall backend must be
|
||||
// idempotent on its own so a double-apply cannot leak rules, matching
|
||||
// the route path and the kernel backends.
|
||||
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "first add")
|
||||
|
||||
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "second add")
|
||||
|
||||
assert.Equal(t, first.ID(), second.ID(), "duplicate add should return the same rule id")
|
||||
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
|
||||
}
|
||||
|
||||
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
|
||||
// backend's no-refcount contract: a content key installed twice is
|
||||
// still one rule, and the first DeleteFilterRule removes it. The
|
||||
// backend does not refcount, so balance is the caller's job (it keys
|
||||
// its tracking by the returned id and deletes once per key). If this
|
||||
// ever silently grew a refcount, the acl manager's delete accounting
|
||||
// would diverge from the kernel.
|
||||
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "first add")
|
||||
|
||||
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "second add")
|
||||
require.Equal(t, first.ID(), second.ID(), "dedup to one rule")
|
||||
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
|
||||
|
||||
require.NoError(t, m.DeleteFilterRule(first), "delete once")
|
||||
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
|
||||
assert.NotContains(t, m.peerRulesMap, first.ID(), "dedup map entry cleared")
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
|
||||
// content hash, not a random UUID: identical inputs produce the same id
|
||||
// across independent managers. A random id breaks caller-side dedup.
|
||||
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
|
||||
ip := net.ParseIP("10.0.0.5")
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
|
||||
m1 := newTestManager(t)
|
||||
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
m2 := newTestManager(t)
|
||||
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, r1.ID(), r2.ID(), "same inputs must produce the same rule id")
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
|
||||
// differing only by port are kept separate.
|
||||
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
action := fw.ActionAccept
|
||||
|
||||
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, r80.ID(), r443.ID(), "different ports must produce different rule ids")
|
||||
assert.Len(t, m.incomingAcceptRules, 2, "distinct rules must both be stored")
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
|
||||
// matching on source port and one matching on destination port for the
|
||||
// same selector do not collide: the port lands in a different slot, so
|
||||
// the content key must differ.
|
||||
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
|
||||
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, dPortRule.ID(), sPortRule.ID(), "source-port and dest-port matches must produce different rule ids")
|
||||
}
|
||||
|
||||
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
|
||||
// list is rejected rather than treated as "match any". "Match any" must
|
||||
// be an explicit /0, so a zeroed list can never silently widen a rule to
|
||||
// every source.
|
||||
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
|
||||
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
|
||||
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
|
||||
assert.Empty(t, m.incomingAcceptRules, "no rule should be stored for empty sources")
|
||||
}
|
||||
105
client/firewall/uspfilter/peer_acl_ipv6_test.go
Normal file
105
client/firewall/uspfilter/peer_acl_ipv6_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func newV6TestManager(t *testing.T, localV6 string) *Manager {
|
||||
t.Helper()
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
IPv6: netip.MustParseAddr(localV6),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
||||
return m
|
||||
}
|
||||
|
||||
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
ip6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
SrcIP: net.ParseIP(src),
|
||||
DstIP: net.ParseIP(dst),
|
||||
}
|
||||
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
|
||||
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
|
||||
// rules: a matching v6 source is accepted, a non-matching one is
|
||||
// denied by the default. This is the end-to-end proof that the index
|
||||
// is not v4-only.
|
||||
func TestPeerACL_IPv6HostRule(t *testing.T) {
|
||||
m := newV6TestManager(t, "fd00::100")
|
||||
|
||||
src := net.ParseIP("fd00::1")
|
||||
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "add v6 accept rule")
|
||||
|
||||
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
|
||||
"v6 packet from the allowed /128 source must be accepted")
|
||||
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
|
||||
"v6 packet from an unlisted source must be denied by default")
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
|
||||
// right index bucket: a /128 in bySource keyed by its address, and
|
||||
// coarser prefixes (including ::/0) in the nonHost slice.
|
||||
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
|
||||
m := newV6TestManager(t, "fd00::100")
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
|
||||
host := netip.MustParseAddr("fd00::1")
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
|
||||
|
||||
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.incomingAcceptIndex.nonHost, 1, "coarser v6 prefix must land in nonHost")
|
||||
|
||||
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.incomingAcceptIndex.nonHost, 2, "::/0 source must also land in nonHost")
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
|
||||
// source prefix is normalized to v4 so a plain v4 packet matches it.
|
||||
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
|
||||
v4 := netip.MustParseAddr("192.168.1.1")
|
||||
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
|
||||
}
|
||||
139
client/firewall/uspfilter/peer_index.go
Normal file
139
client/firewall/uspfilter/peer_index.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// peerRuleIndex is the source-side dispatcher consulted on the packet
|
||||
// hot path. It splits rules into two buckets by the shape of their
|
||||
// source list:
|
||||
//
|
||||
// - bySource: every source is a host prefix (/32 for v4, /128 for
|
||||
// v6). Keyed by the concrete source address, so a hit guarantees
|
||||
// the source filter passes and the matcher goes straight to
|
||||
// proto/port checks. This is the common case for peer ACLs.
|
||||
// - nonHost: any source list with a prefix coarser than a host,
|
||||
// including a /0 "match any". Walked linearly with a per-rule
|
||||
// Contains() check. Expected small or empty for typical peer ACLs.
|
||||
//
|
||||
// Maintained incrementally by add/remove, never rebuilt.
|
||||
type peerRuleIndex struct {
|
||||
bySource map[netip.Addr][]*PeerRule
|
||||
nonHost []*PeerRule
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) add(r *PeerRule) {
|
||||
if hasNonHostSource(r) {
|
||||
i.nonHost = append(i.nonHost, r)
|
||||
return
|
||||
}
|
||||
if i.bySource == nil {
|
||||
i.bySource = make(map[netip.Addr][]*PeerRule)
|
||||
}
|
||||
for a := range r.sourceAddrs {
|
||||
i.bySource[a] = append(i.bySource[a], r)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) remove(r *PeerRule) {
|
||||
if hasNonHostSource(r) {
|
||||
i.nonHost = slices.DeleteFunc(i.nonHost, eqRule(r))
|
||||
return
|
||||
}
|
||||
if i.bySource == nil {
|
||||
return
|
||||
}
|
||||
for a := range r.sourceAddrs {
|
||||
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
|
||||
if len(entries) == 0 {
|
||||
delete(i.bySource, a)
|
||||
} else {
|
||||
i.bySource[a] = entries
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) reset() {
|
||||
i.bySource = nil
|
||||
i.nonHost = i.nonHost[:0]
|
||||
}
|
||||
|
||||
// match returns the first rule matching src and the decoded packet.
|
||||
// Host rules are found by direct map lookup; nonHost rules need a
|
||||
// per-rule source Contains() check, except match-any (/0) rules which
|
||||
// apply to every source regardless of family (a v4 /0 also matches v6).
|
||||
// Within either bucket the matcher runs the proto/port filter.
|
||||
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
for _, rule := range i.bySource[src] {
|
||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||
return id, drop, true
|
||||
}
|
||||
}
|
||||
for _, rule := range i.nonHost {
|
||||
if !rule.matchAny && !prefixesContain(rule.sources, src) {
|
||||
continue
|
||||
}
|
||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||
return id, drop, true
|
||||
}
|
||||
}
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
func eqRule(target *PeerRule) func(*PeerRule) bool {
|
||||
return func(p *PeerRule) bool { return p == target }
|
||||
}
|
||||
|
||||
// hasNonHostSource reports whether the rule has any source prefix
|
||||
// that is not a single host address. Called only at add/remove time,
|
||||
// not on the packet path.
|
||||
func hasNonHostSource(r *PeerRule) bool {
|
||||
for _, p := range r.sources {
|
||||
if p.Bits() != p.Addr().BitLen() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchProto applies the proto/port half of a rule against the
|
||||
// decoded packet. Source matching is the caller's responsibility.
|
||||
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
|
||||
drop := rule.action == firewall.ActionDrop
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
return nil, false, false
|
||||
}
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
|
||||
for _, p := range sources {
|
||||
if p.Contains(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -10,24 +10,49 @@ import (
|
||||
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
ip netip.Addr
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
id firewall.RuleID
|
||||
mgmtId []byte
|
||||
// sources is the canonical list of source prefixes this rule
|
||||
// matches against.
|
||||
sources []netip.Prefix
|
||||
// sourceAddrs is a fast-path membership set for host-prefix
|
||||
// sources (/32 v4, /128 v6). Populated alongside sources;
|
||||
// consulted before falling back to prefix scan.
|
||||
sourceAddrs map[netip.Addr]struct{}
|
||||
// matchAny is true when sources covers everything (0.0.0.0/0,
|
||||
// ::/0). In that case neither sourceAddrs nor sources need to be
|
||||
// consulted.
|
||||
matchAny bool
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// matchesSource reports whether the given source address is covered
|
||||
// by this rule's source list.
|
||||
func (r *PeerRule) matchesSource(src netip.Addr) bool {
|
||||
if r.matchAny {
|
||||
return true
|
||||
}
|
||||
if _, ok := r.sourceAddrs[src]; ok {
|
||||
return true
|
||||
}
|
||||
for _, p := range r.sources {
|
||||
if p.Contains(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
func (r *PeerRule) ID() string {
|
||||
func (r *PeerRule) ID() firewall.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
id firewall.RuleID
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
@@ -39,6 +64,6 @@ type RouteRule struct {
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
func (r *RouteRule) ID() string {
|
||||
func (r *RouteRule) ID() firewall.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
50
client/firewall/uspfilter/testhelpers_test.go
Normal file
50
client/firewall/uspfilter/testhelpers_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// countRulesForAddr reports how many rules in the given slice match
|
||||
// the supplied source address.
|
||||
func countRulesForAddr(rules peerRules, src netip.Addr) int {
|
||||
n := 0
|
||||
for _, r := range rules {
|
||||
if r.matchesSource(src) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// findRuleByID returns true if the rules slice contains a rule with
|
||||
// the given id whose source set covers src.
|
||||
func findRuleByID(rules peerRules, src netip.Addr, id firewall.RuleID) bool {
|
||||
for _, r := range rules {
|
||||
if r.id == id && r.matchesSource(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pfx converts a single net.IP into the []netip.Prefix form
|
||||
// AddFilterRule expects. A nil or unspecified address becomes a /0
|
||||
// ("match any") prefix in the matching family; any other address
|
||||
// becomes its /32 (or /128) host prefix.
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, _ := netip.AddrFromSlice(ip)
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -285,6 +285,14 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
trace.SourceIP = srcIP
|
||||
trace.DestinationIP = dstIP
|
||||
|
||||
// A fragment or otherwise truncated packet has no transport layer.
|
||||
// The inbound datapath drops these via isValidPacket; the tracer must
|
||||
// guard explicitly since every downstream stage reads d.decoded[1].
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageReceived, "Packet has no transport layer (fragment or unsupported protocol)", false)
|
||||
return trace
|
||||
}
|
||||
|
||||
// Determine protocol and ports
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
!define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
|
||||
!define INSTALLER_NAME "netbird-installer.exe"
|
||||
!define MAIN_APP_EXE "Netbird"
|
||||
!define ICON "ui\\build\\windows\\icon.ico"
|
||||
!define ICON "ui\\assets\\netbird.ico"
|
||||
!define BANNER "ui\\build\\banner.bmp"
|
||||
!define LICENSE_DATA "..\\LICENSE"
|
||||
|
||||
@@ -280,43 +280,6 @@ CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
SectionEnd
|
||||
|
||||
# Install the Microsoft Edge WebView2 runtime if it isn't already present.
|
||||
# Macro adapted from Wails3's NSIS template (wails_tools.nsh): a registry
|
||||
# probe followed by a silent install of the embedded evergreen bootstrapper.
|
||||
# The MicrosoftEdgeWebview2Setup.exe payload is staged next to this script
|
||||
# by the sign-pipelines build step (`wails3 generate webview2bootstrapper`).
|
||||
!macro nb.webview2runtime
|
||||
SetRegView 64
|
||||
# Per-machine install marker — populated when the runtime ships with
|
||||
# Edge or has been installed by an admin previously.
|
||||
ReadRegStr $0 HKLM "SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
|
||||
${If} $0 != ""
|
||||
Goto webview2_ok
|
||||
${EndIf}
|
||||
# Per-user fallback for HKCU installs.
|
||||
ReadRegStr $0 HKCU "Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
|
||||
${If} $0 != ""
|
||||
Goto webview2_ok
|
||||
${EndIf}
|
||||
|
||||
SetDetailsPrint both
|
||||
DetailPrint "Installing: WebView2 Runtime"
|
||||
SetDetailsPrint listonly
|
||||
|
||||
InitPluginsDir
|
||||
CreateDirectory "$pluginsdir\webview2bootstrapper"
|
||||
SetOutPath "$pluginsdir\webview2bootstrapper"
|
||||
File "MicrosoftEdgeWebview2Setup.exe"
|
||||
ExecWait '"$pluginsdir\webview2bootstrapper\MicrosoftEdgeWebview2Setup.exe" /silent /install'
|
||||
|
||||
SetDetailsPrint both
|
||||
webview2_ok:
|
||||
!macroend
|
||||
|
||||
Section -WebView2
|
||||
!insertmacro nb.webview2runtime
|
||||
SectionEnd
|
||||
|
||||
Section -Post
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install'
|
||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start'
|
||||
@@ -363,9 +326,9 @@ DetailPrint "Deleting application files..."
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
# Legacy: pre-Wails installs shipped opengl32.dll (Mesa3D for Fyne); remove
|
||||
# any leftover copy on uninstall so old upgrades don't leave it behind.
|
||||
!if ${ARCH} == "amd64"
|
||||
Delete "$INSTDIR\opengl32.dll"
|
||||
!endif
|
||||
DetailPrint "Removing application directory..."
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
|
||||
190
client/internal/acl/dispatch_test.go
Normal file
190
client/internal/acl/dispatch_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
|
||||
// invariant: the backends classify a rule as a peer rule purely by the
|
||||
// absence of a destination (neither prefix nor set). A default route
|
||||
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
|
||||
// a route, not collapse into the peer path.
|
||||
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
|
||||
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
|
||||
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
|
||||
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
|
||||
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
|
||||
}
|
||||
|
||||
// A zero-value Network is the only peer-rule shape.
|
||||
var empty fwmgr.Network
|
||||
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
|
||||
assert.False(t, empty.IsSet(), "zero Network must not be a set")
|
||||
}
|
||||
|
||||
// TestDetermineDestinationAlwaysRoute verifies determineDestination
|
||||
// never yields an empty Network for a valid route rule: every branch
|
||||
// (static prefix, default route, dynamic with/without domains, with and
|
||||
// without a local resolver) produces a destination that classifies as a
|
||||
// route. If this regresses, a route rule would be dispatched down the
|
||||
// peer path, which matches on source only.
|
||||
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
|
||||
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
||||
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
rule *mgmProto.RouteFirewallRule
|
||||
resolver bool
|
||||
sources []netip.Prefix
|
||||
}{
|
||||
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
|
||||
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
|
||||
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
|
||||
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
|
||||
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
|
||||
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
|
||||
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, dest.IsPrefix() || dest.IsSet(),
|
||||
"destination must classify as a route, got empty Network")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// countingFirewall wraps a real firewall.Manager and counts filter-rule
|
||||
// add/delete calls so a test can assert how many backing rules the acl
|
||||
// manager actually creates and tears down.
|
||||
type countingFirewall struct {
|
||||
fwmgr.Manager
|
||||
mu sync.Mutex
|
||||
addCalls int
|
||||
dels int
|
||||
ruleIDs map[fwmgr.RuleID]struct{}
|
||||
}
|
||||
|
||||
// distinctRules returns the number of distinct backing rules the
|
||||
// backend produced. Because the backend dedups identical content,
|
||||
// repeated AddFilterRule calls for the same rule resolve to one id.
|
||||
func (f *countingFirewall) distinctRules() int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return len(f.ruleIDs)
|
||||
}
|
||||
|
||||
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
|
||||
rule, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
if err == nil {
|
||||
f.mu.Lock()
|
||||
f.addCalls++
|
||||
if f.ruleIDs == nil {
|
||||
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
|
||||
}
|
||||
if rule != nil {
|
||||
f.ruleIDs[rule.ID()] = struct{}{}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
return rule, err
|
||||
}
|
||||
|
||||
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
||||
err := f.Manager.DeleteFilterRule(r)
|
||||
if err == nil {
|
||||
f.mu.Lock()
|
||||
f.dels++
|
||||
delete(f.ruleIDs, r.ID())
|
||||
f.mu.Unlock()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
|
||||
t.Helper()
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
fw := &countingFirewall{Manager: realFW}
|
||||
cleanup := func() {
|
||||
require.NoError(t, realFW.Close(nil))
|
||||
ctrl.Finish()
|
||||
}
|
||||
return NewDefaultManager(fw), fw, cleanup
|
||||
}
|
||||
|
||||
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
|
||||
// the backends rely on: two policies that authorize an identical flow
|
||||
// (same selector and sources) collapse to a single backing firewall
|
||||
// rule, and that rule survives until BOTH policies are gone. This is
|
||||
// why the backend can dedup on add without refcounting on delete: the
|
||||
// acl manager's pair key matches the backend's content key, so add and
|
||||
// delete stay balanced per content key across full-state reapplies.
|
||||
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
|
||||
acl, fw, cleanup := newCountingACL(t)
|
||||
defer cleanup()
|
||||
|
||||
ruleA := &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
ruleB := &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-B"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
|
||||
// Both policies present: identical content collapses to one rule.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
|
||||
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
|
||||
|
||||
// Drop policy A only: the shared rule is still authorized by B, so
|
||||
// nothing is deleted.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
|
||||
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
|
||||
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||
|
||||
// Drop policy B too: now the content key has no authorizer and the
|
||||
// single backing rule is removed exactly once.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
|
||||
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||
}
|
||||
318
client/internal/acl/grouping_test.go
Normal file
318
client/internal/acl/grouping_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
|
||||
// with identical selectors but different PolicyIDs do NOT get merged
|
||||
// into one group, so each policy's sources merge under its own
|
||||
// attribution id. (Identical-content groups may still dedup to one
|
||||
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
|
||||
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-B"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
|
||||
// belonging to the same policy don't merge.
|
||||
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "2001:db8::1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "rules of different families must produce separate groups")
|
||||
|
||||
var sawV4, sawV6 bool
|
||||
for _, g := range groups {
|
||||
require.Len(t, g.sources, 1)
|
||||
if g.sources[0].Addr().Is4() {
|
||||
sawV4 = true
|
||||
}
|
||||
if g.sources[0].Addr().Is6() {
|
||||
sawV6 = true
|
||||
}
|
||||
}
|
||||
assert.True(t, sawV4 && sawV6)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesSplitsMixedFamilySingleRule verifies that a single
|
||||
// FirewallRule carrying both v4 and v6 source prefixes is split into one
|
||||
// group per family. Each backend keys a rule to a single family, so a
|
||||
// group whose sources span families would mismatch the other family's
|
||||
// sources. mgmt normally emits one rule per family; this guards against
|
||||
// a mixed-family rule slipping through.
|
||||
func TestGroupPeerRulesSplitsMixedFamilySingleRule(t *testing.T) {
|
||||
srcs := [][]byte{
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::1")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::2")),
|
||||
}
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
SourcePrefixes: srcs,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "mixed-family sources in one rule must split into two groups")
|
||||
|
||||
for _, g := range groups {
|
||||
require.Len(t, g.sources, 2)
|
||||
v6 := prefixIsV6(g.sources[0])
|
||||
for _, s := range g.sources {
|
||||
assert.Equal(t, v6, prefixIsV6(s), "every source in a group must share one family")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
|
||||
// every distinguishing field (policy, family, direction, action,
|
||||
// proto, port) collapse into a single multi-source group.
|
||||
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
|
||||
mk := func(peerIP string) *mgmProto.FirewallRule {
|
||||
return &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: peerIP,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
}
|
||||
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 1)
|
||||
require.Len(t, groups[0].sources, 3)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesPortSeparates verifies that PortInfo is part of the
|
||||
// selector key: rules differing only in port must not merge, and a
|
||||
// single port must not merge with a range. A regression dropping the
|
||||
// port from the key would collapse rules for different ports into one.
|
||||
func TestGroupPeerRulesPortSeparates(t *testing.T) {
|
||||
mkPort := func(peerIP string, port uint32) *mgmProto.FirewallRule {
|
||||
return &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: peerIP,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Port{Port: port}},
|
||||
}
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules([]*mgmProto.FirewallRule{
|
||||
mkPort("10.0.0.1", 80), mkPort("10.0.0.2", 80), mkPort("10.0.0.3", 443),
|
||||
})
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "rules on different ports must not merge")
|
||||
|
||||
rangeRule := &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Range_{Range: &mgmProto.PortInfo_Range{Start: 80, End: 90}}},
|
||||
}
|
||||
groups, denyErr, err = groupPeerRules([]*mgmProto.FirewallRule{mkPort("10.0.0.1", 80), rangeRule})
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "a single port and a range must not merge")
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
|
||||
// new sourcePrefixes wire field is consumed and produces a
|
||||
// multi-source group in one shot (no client-side merging needed).
|
||||
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
|
||||
srcs := [][]byte{
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
|
||||
}
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
SourcePrefixes: srcs,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 1)
|
||||
require.Len(t, groups[0].sources, 3)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
|
||||
// and drop rules with the same selector don't merge.
|
||||
func TestGroupPeerRulesActionSeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2)
|
||||
}
|
||||
|
||||
// failingDeleteFirewall wraps a real firewall.Manager and forces the
|
||||
// next N DeleteFilterRule calls to fail. Used to verify that the acl
|
||||
// manager retains rules whose deletion was rejected by the backend,
|
||||
// so they get retried on the next ApplyFiltering pass instead of
|
||||
// becoming orphans.
|
||||
type failingDeleteFirewall struct {
|
||||
fwmgr.Manager
|
||||
failCount int
|
||||
}
|
||||
|
||||
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
||||
if f.failCount > 0 {
|
||||
f.failCount--
|
||||
return errors.New("simulated delete failure")
|
||||
}
|
||||
return f.Manager.DeleteFilterRule(r)
|
||||
}
|
||||
|
||||
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
|
||||
// transient DeleteFilterRule error doesn't make the acl manager
|
||||
// forget about a rule. The rule must remain in peerRulesPairs so the
|
||||
// next ApplyFiltering pass attempts the delete again.
|
||||
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, realFW.Close(nil)) }()
|
||||
|
||||
fw := &failingDeleteFirewall{Manager: realFW}
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First pass: install a rule.
|
||||
netmap1 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
acl.ApplyFiltering(netmap1, false)
|
||||
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
|
||||
|
||||
// Second pass: remove the rule from the map. The backend will
|
||||
// fail the delete; the acl manager must retain the rule.
|
||||
fw.failCount = 1
|
||||
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
|
||||
acl.ApplyFiltering(netmap2, false)
|
||||
require.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"rule must be retained when DeleteFilterRule fails so it gets retried")
|
||||
|
||||
// Third pass: same map, backend no longer fails. The rule
|
||||
// should now succeed in being removed.
|
||||
acl.ApplyFiltering(netmap2, false)
|
||||
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
|
||||
}
|
||||
@@ -5,18 +5,18 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
type RuleID string
|
||||
// RuleID aliases manager.RuleID so existing nbid.RuleID references
|
||||
// keep working while the canonical type lives in the firewall package.
|
||||
type RuleID = manager.RuleID
|
||||
|
||||
func (r RuleID) ID() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func GenerateRouteRuleKey(
|
||||
// GenerateRuleID returns a deterministic content hash identifying a filter rule.
|
||||
func GenerateRuleID(
|
||||
sources []netip.Prefix,
|
||||
destination manager.Network,
|
||||
proto manager.Protocol,
|
||||
@@ -24,6 +24,7 @@ func GenerateRouteRuleKey(
|
||||
dPort *manager.Port,
|
||||
action manager.Action,
|
||||
) RuleID {
|
||||
sources = slices.Clone(sources)
|
||||
manager.SortPrefixes(sources)
|
||||
|
||||
h := sha256.New()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
@@ -23,6 +21,10 @@ import (
|
||||
|
||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||
|
||||
// ErrNoRuleReturned is returned when the firewall backend reports success
|
||||
// from AddFilterRule but yields no rule to track.
|
||||
var ErrNoRuleReturned = errors.New("backend returned no rule")
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||
@@ -31,17 +33,46 @@ type Manager interface {
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||
routeRules map[id.RuleID]struct{}
|
||||
routeRules map[id.RuleID]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// peerRuleGroup collapses a set of single-source FirewallRules sharing
|
||||
// the same selector into one multi-source rule to push to the backend.
|
||||
type peerRuleGroup struct {
|
||||
direction mgmProto.RuleDirection
|
||||
action mgmProto.RuleAction
|
||||
protocol mgmProto.RuleProtocol
|
||||
port *mgmProto.PortInfo
|
||||
// legacyPort is used only when PortInfo is empty (old management).
|
||||
legacyPort string
|
||||
policyID []byte
|
||||
sources []netip.Prefix
|
||||
}
|
||||
|
||||
// peerRuleKey is the comparable selector that decides which single-source
|
||||
// rules merge into one group. Rules with an equal key collapse into one
|
||||
// multi-source backend rule. PortInfo is flattened into its scalar fields
|
||||
// so the key compares by value; policyID keeps policies separate so two
|
||||
// policies authorizing different peers don't merge under one attribution.
|
||||
type peerRuleKey struct {
|
||||
v6 bool
|
||||
policyID string
|
||||
direction mgmProto.RuleDirection
|
||||
action mgmProto.RuleAction
|
||||
protocol mgmProto.RuleProtocol
|
||||
legacyPort string
|
||||
port uint16
|
||||
rangeStart uint16
|
||||
rangeEnd uint16
|
||||
}
|
||||
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
firewall: fm,
|
||||
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
||||
routeRules: make(map[id.RuleID]struct{}),
|
||||
routeRules: make(map[id.RuleID]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,10 +99,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
time.Since(start), total)
|
||||
}()
|
||||
|
||||
d.applyPeerACLs(networkMap)
|
||||
if err := d.applyPeerACLs(networkMap); err != nil {
|
||||
log.Errorf("apply peer ACLs: %v", err)
|
||||
}
|
||||
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
log.Errorf("apply route ACLs: %v", err)
|
||||
}
|
||||
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
@@ -79,7 +112,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
|
||||
rules := networkMap.FirewallRules
|
||||
|
||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||
@@ -102,59 +135,158 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
)
|
||||
}
|
||||
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
// Group incoming single-source rules from management by their
|
||||
// (direction, action, proto, port) selector and merge sources.
|
||||
// One call to the firewall backend per merged rule.
|
||||
// A deny we cannot decode would leave its traffic unblocked, so skip
|
||||
// the whole pass and keep existing rules until the next sync.
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
if denyErr != nil {
|
||||
return fmt.Errorf("decode deny rule sources: %w", denyErr)
|
||||
}
|
||||
|
||||
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
||||
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
||||
// the missing deny. Currently we accumulate errors and continue.
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
var merr *multierror.Error
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipsetName, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
d.ipsetCounter++
|
||||
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
ipsetByRuleSelectors[selector] = ipsetName
|
||||
}
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// Apply denies first. A deny that fails to install is a security
|
||||
// failure (fail-open), so if any deny errors we roll back the
|
||||
// denies we already installed in this pass and bail out without
|
||||
// installing any accept. Pre-existing rules stay untouched until
|
||||
// the next successful pass clears them.
|
||||
denies, accepts := splitDenyAccept(groups)
|
||||
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
|
||||
return fmt.Errorf("install deny rules: %w", err)
|
||||
}
|
||||
|
||||
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// Tear down rules that disappeared from the networkmap. Any rule
|
||||
// the backend refuses to delete stays in our tracking so it gets
|
||||
// retried on the next ApplyFiltering. Otherwise a transient
|
||||
// delete failure would leak the rule in the firewall until the
|
||||
// process exits.
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; ok {
|
||||
continue
|
||||
}
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
}
|
||||
|
||||
if merr != nil {
|
||||
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete peer firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
var remaining []firewall.Rule
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
|
||||
remaining = append(remaining, rule)
|
||||
}
|
||||
delete(d.peerRulesPairs, pairID)
|
||||
}
|
||||
if len(remaining) > 0 {
|
||||
newRulePairs[pairID] = remaining
|
||||
}
|
||||
}
|
||||
d.peerRulesPairs = newRulePairs
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// installPeerGroups applies each group and records the resulting rule
|
||||
// pairs in newRulePairs. With atomic set (deny rules), a single failure
|
||||
// rolls back every rule installed in this call and returns, leaving the
|
||||
// firewall exactly as before: denies are fail-closed and must be applied
|
||||
// all-or-nothing. With atomic unset (accept rules), failures are
|
||||
// accumulated and the remaining groups still install, so one malformed
|
||||
// allow cannot drop every other legitimate allow in the pass.
|
||||
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
|
||||
var freshlyInstalled []id.RuleID
|
||||
var merr *multierror.Error
|
||||
for _, g := range groups {
|
||||
pairID, rulePair, err := d.applyPeerGroup(g)
|
||||
if err != nil {
|
||||
if atomic {
|
||||
d.rollbackInstalled(freshlyInstalled)
|
||||
return fmt.Errorf("apply firewall rule: %w", err)
|
||||
}
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
continue
|
||||
}
|
||||
if len(rulePair) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, existed := d.peerRulesPairs[pairID]; !existed {
|
||||
freshlyInstalled = append(freshlyInstalled, pairID)
|
||||
}
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
|
||||
var merr *multierror.Error
|
||||
for _, pairID := range pairIDs {
|
||||
for _, rule := range d.peerRulesPairs[pairID] {
|
||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rule %s: %w", pairID, err))
|
||||
}
|
||||
}
|
||||
delete(d.peerRulesPairs, pairID)
|
||||
}
|
||||
if err := nberrors.FormatErrorOrNil(merr); err != nil {
|
||||
log.Errorf("rollback peer rules: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
|
||||
protocol, err := ConvertToFirewallProtocol(g.protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
||||
}
|
||||
action, err := convertFirewallAction(g.action)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
||||
}
|
||||
port, err := resolveGroupPort(g)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
var fwRule firewall.Rule
|
||||
switch g.direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
if d.firewall.IsStateful() {
|
||||
return "", nil, nil
|
||||
}
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return "", nil, nil
|
||||
}
|
||||
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
if fwRule == nil {
|
||||
return "", nil, fmt.Errorf("add firewall rule: %w", ErrNoRuleReturned)
|
||||
}
|
||||
|
||||
// Derive the pair id from the backend rule, like the route path:
|
||||
// the backend dedups identical content, so two policies authorizing
|
||||
// the same flow resolve to the same id and a single backing rule.
|
||||
return fwRule.ID(), []firewall.Rule{fwRule}, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||
newRouteRules := make(map[id.RuleID]firewall.Rule, len(rules))
|
||||
var merr *multierror.Error
|
||||
|
||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||
// Apply new rules - firewall manager will return the existing rule if already present
|
||||
for _, rule := range rules {
|
||||
id, err := d.applyRouteACL(rule, dynamicResolver)
|
||||
addedRule, err := d.applyRouteACL(rule, dynamicResolver)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||
@@ -163,16 +295,18 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
||||
}
|
||||
continue
|
||||
}
|
||||
newRouteRules[id] = struct{}{}
|
||||
newRouteRules[addedRule.ID()] = addedRule
|
||||
}
|
||||
|
||||
// Clean up old firewall rules
|
||||
for id := range d.routeRules {
|
||||
if _, exists := newRouteRules[id]; !exists {
|
||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
||||
}
|
||||
// implicitly deleted from the map
|
||||
// Tear down old route rules; retain ones the backend refused so a
|
||||
// transient failure doesn't leave orphaned rules in the firewall.
|
||||
for ruleID, rule := range d.routeRules {
|
||||
if _, exists := newRouteRules[ruleID]; exists {
|
||||
continue
|
||||
}
|
||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
|
||||
newRouteRules[ruleID] = rule
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,102 +314,196 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (firewall.Rule, error) {
|
||||
if len(rule.SourceRanges) == 0 {
|
||||
return "", ErrSourceRangesEmpty
|
||||
return nil, ErrSourceRangesEmpty
|
||||
}
|
||||
|
||||
var sources []netip.Prefix
|
||||
for _, sourceRange := range rule.SourceRanges {
|
||||
source, err := netip.ParsePrefix(sourceRange)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse source range: %w", err)
|
||||
return nil, fmt.Errorf("parse source range: %w", err)
|
||||
}
|
||||
sources = append(sources, source)
|
||||
sources = append(sources, firewall.UnmapPrefix(source))
|
||||
}
|
||||
|
||||
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("determine destination: %w", err)
|
||||
return nil, fmt.Errorf("determine destination: %w", err)
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||
protocol, err := ConvertToFirewallProtocol(rule.Protocol)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid protocol: %w", err)
|
||||
return nil, fmt.Errorf("invalid protocol: %w", err)
|
||||
}
|
||||
|
||||
action, err := convertFirewallAction(rule.Action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid action: %w", err)
|
||||
return nil, fmt.Errorf("invalid action: %w", err)
|
||||
}
|
||||
|
||||
dPorts := convertPortInfo(rule.PortInfo)
|
||||
|
||||
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||
addedRule, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add route rule: %w", err)
|
||||
return nil, fmt.Errorf("add route rule: %w", err)
|
||||
}
|
||||
if addedRule == nil {
|
||||
return nil, fmt.Errorf("add route rule: %w", ErrNoRuleReturned)
|
||||
}
|
||||
|
||||
return id.RuleID(addedRule.ID()), nil
|
||||
return addedRule, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
ip, err := extractRuleIP(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
// splitDenyAccept partitions groups by action so denies can be
|
||||
// applied before accepts. Order within each bucket is preserved.
|
||||
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
|
||||
for _, g := range groups {
|
||||
if g.action == mgmProto.RuleAction_DROP {
|
||||
denies = append(denies, g)
|
||||
} else {
|
||||
accepts = append(accepts, g)
|
||||
}
|
||||
}
|
||||
return denies, accepts
|
||||
}
|
||||
|
||||
// groupPeerRules merges single-source rules sharing a selector into
|
||||
// multi-source groups. It splits source-decode failures by action:
|
||||
// denyErr is non-nil when a deny rule could not be decoded, which is a
|
||||
// fail-open risk the caller must treat as fatal for the pass; err
|
||||
// carries the tolerable accept-rule failures the caller can log and
|
||||
// continue past.
|
||||
func groupPeerRules(rules []*mgmProto.FirewallRule) (groups []*peerRuleGroup, denyErr error, err error) {
|
||||
var denyMerr, acceptMerr *multierror.Error
|
||||
byKey := make(map[peerRuleKey]*peerRuleGroup)
|
||||
order := make([]peerRuleKey, 0)
|
||||
|
||||
for _, r := range rules {
|
||||
srcs, decErr := extractRuleSources(r)
|
||||
if decErr != nil {
|
||||
if r.Action == mgmProto.RuleAction_DROP {
|
||||
denyMerr = multierror.Append(denyMerr, decErr)
|
||||
} else {
|
||||
acceptMerr = multierror.Append(acceptMerr, decErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// A single FirewallRule normally carries one address family, but
|
||||
// split by family defensively: each backend keys a rule to one
|
||||
// family and would mismatch sources of the other, so a group's
|
||||
// sources must never span families.
|
||||
v4, v6 := splitPrefixesByFamily(srcs)
|
||||
for _, sub := range []struct {
|
||||
isV6 bool
|
||||
sources []netip.Prefix
|
||||
}{{false, v4}, {true, v6}} {
|
||||
if len(sub.sources) == 0 {
|
||||
continue
|
||||
}
|
||||
key := ruleGroupKey(r, sub.isV6)
|
||||
g, ok := byKey[key]
|
||||
if !ok {
|
||||
g = &peerRuleGroup{
|
||||
direction: r.Direction,
|
||||
action: r.Action,
|
||||
protocol: r.Protocol,
|
||||
port: r.PortInfo,
|
||||
legacyPort: r.Port,
|
||||
policyID: r.PolicyID,
|
||||
}
|
||||
byKey[key] = g
|
||||
order = append(order, key)
|
||||
}
|
||||
g.sources = append(g.sources, sub.sources...)
|
||||
}
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
out := make([]*peerRuleGroup, 0, len(order))
|
||||
for _, k := range order {
|
||||
out = append(out, byKey[k])
|
||||
}
|
||||
return out, nberrors.FormatErrorOrNil(denyMerr), nberrors.FormatErrorOrNil(acceptMerr)
|
||||
}
|
||||
|
||||
func prefixIsV6(p netip.Prefix) bool {
|
||||
return p.Addr().Is6() && !p.Addr().Is4In6()
|
||||
}
|
||||
|
||||
// splitPrefixesByFamily partitions prefixes into IPv4 and IPv6 groups.
|
||||
func splitPrefixesByFamily(prefixes []netip.Prefix) (v4, v6 []netip.Prefix) {
|
||||
for _, p := range prefixes {
|
||||
if prefixIsV6(p) {
|
||||
v6 = append(v6, p)
|
||||
} else {
|
||||
v4 = append(v4, p)
|
||||
}
|
||||
}
|
||||
return v4, v6
|
||||
}
|
||||
|
||||
// ruleGroupKey builds the selector key for a rule. v6 must reflect the
|
||||
// rule's source family: mgmt emits one rule per family and mixing them
|
||||
// would break ICMP-variant selection in uspfilter.
|
||||
func ruleGroupKey(r *mgmProto.FirewallRule, v6 bool) peerRuleKey {
|
||||
k := peerRuleKey{
|
||||
v6: v6,
|
||||
policyID: string(r.PolicyID),
|
||||
direction: r.Direction,
|
||||
action: r.Action,
|
||||
protocol: r.Protocol,
|
||||
legacyPort: r.Port,
|
||||
}
|
||||
if pi := r.PortInfo; pi != nil {
|
||||
k.port = uint16(pi.GetPort())
|
||||
if rng := pi.GetRange(); rng != nil {
|
||||
k.rangeStart = uint16(rng.GetStart())
|
||||
k.rangeEnd = uint16(rng.GetEnd())
|
||||
}
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
// extractRuleSources returns all source prefixes the rule applies to.
|
||||
// New management populates sourcePrefixes; older management sets PeerIP.
|
||||
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
|
||||
if len(r.SourcePrefixes) > 0 {
|
||||
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
|
||||
for _, raw := range r.SourcePrefixes {
|
||||
addr, err := netiputil.DecodeAddr(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode source prefix: %w", err)
|
||||
}
|
||||
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
action, err := convertFirewallAction(r.Action)
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
addr, err := netip.ParseAddr(r.PeerIP)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if !portInfoEmpty(r.PortInfo) {
|
||||
port = convertPortInfo(r.PortInfo)
|
||||
} else if r.Port != "" {
|
||||
// old version of management, single port
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
|
||||
if !portInfoEmpty(g.port) {
|
||||
return convertPortInfo(g.port), nil
|
||||
}
|
||||
if g.legacyPort != "" {
|
||||
value, err := strconv.ParseUint(g.legacyPort, 10, 16)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid port: %w", err)
|
||||
return nil, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
port = &firewall.Port{
|
||||
return &firewall.Port{
|
||||
Values: []uint16{uint16(value)},
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
if d.firewall.IsStateful() {
|
||||
return "", nil, nil
|
||||
}
|
||||
// return traffic for outbound connections if firewall is stateless
|
||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return ruleID, rules, nil
|
||||
// nolint:nilnil // a nil port legitimately means "no port restriction"
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
@@ -294,85 +522,9 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
id []byte,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
id []byte,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
ip netip.Addr,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
) id.RuleID {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
|
||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
|
||||
// extractRuleIP extracts the peer IP from a firewall rule.
|
||||
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
||||
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
||||
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
||||
if len(r.SourcePrefixes) > 0 {
|
||||
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
addr, err := netip.ParseAddr(r.PeerIP)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
// ConvertToFirewallProtocol maps a management rule protocol to the
|
||||
// firewall protocol type.
|
||||
func ConvertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.RuleProtocol_TCP:
|
||||
return firewall.ProtocolTCP, nil
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
@@ -76,9 +77,9 @@ func TestDefaultManager(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
existedPairs := map[string]struct{}{}
|
||||
existedPairs := map[fwmanager.RuleID]struct{}{}
|
||||
for id := range acl.peerRulesPairs {
|
||||
existedPairs[id.ID()] = struct{}{}
|
||||
existedPairs[id] = struct{}{}
|
||||
}
|
||||
|
||||
// remove first rule
|
||||
@@ -105,7 +106,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.peerRulesPairs {
|
||||
if _, ok := existedPairs[id.ID()]; ok {
|
||||
if _, ok := existedPairs[id]; ok {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -22,25 +21,6 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// peerLoginExpiredMsg is the exact phrase the management server returns
|
||||
// when a previously SSO-enrolled peer's login has expired. Sourced from
|
||||
// shared/management/status/error.go (NewPeerLoginExpiredError). Matched
|
||||
// by substring so a future server-side rewording that keeps the phrase
|
||||
// still triggers the friendly fallback in Login().
|
||||
const peerLoginExpiredMsg = "peer login has expired"
|
||||
|
||||
// errSetupKeyOnSSOExpiredPeer replaces the raw management error when the
|
||||
// user runs `netbird login -k <setup-key>` against a peer that was
|
||||
// originally enrolled via SSO. Wrapped in a PermissionDenied gRPC status
|
||||
// so callers' existing isPermissionDenied / isAuthError checks still
|
||||
// classify it correctly (early-exit from retry backoff, StatusNeedsLogin
|
||||
// in the server state machine).
|
||||
var errSetupKeyOnSSOExpiredPeer = status.Error(
|
||||
codes.PermissionDenied,
|
||||
"this peer was originally enrolled via SSO and its session has expired. "+
|
||||
"Setup keys can only enrol new peers — run `netbird up` (interactive SSO) to re-login.",
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
@@ -204,15 +184,6 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
// The peer pub-key is already on file with the management
|
||||
// server (originally enrolled via SSO) and the session has
|
||||
// expired. The setup-key path can only enrol new peers, so
|
||||
// retrying with -k will keep failing. Replace the raw mgm
|
||||
// message with an actionable hint that tells the user to
|
||||
// re-authenticate via SSO instead.
|
||||
if setupKey != "" && jwtToken == "" && isPeerLoginExpired(err) {
|
||||
err = errSetupKeyOnSSOExpiredPeer
|
||||
}
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
@@ -503,16 +474,3 @@ func isLoginNeeded(err error) bool {
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
|
||||
// isPeerLoginExpired reports whether err is the management server's
|
||||
// "peer login has expired" PermissionDenied response. Used by Login to
|
||||
// detect the case where the caller passed a setup-key but the peer is
|
||||
// actually an SSO-enrolled record whose session needs refreshing — the
|
||||
// setup-key path cannot help there.
|
||||
func isPeerLoginExpired(err error) bool {
|
||||
if !isPermissionDenied(err) {
|
||||
return false
|
||||
}
|
||||
s, _ := status.FromError(err)
|
||||
return strings.Contains(s.Message(), peerLoginExpiredMsg)
|
||||
}
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestIsPeerLoginExpired(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "plain error (not a gRPC status)",
|
||||
err: errors.New("network read: connection reset"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "PermissionDenied with different message",
|
||||
err: status.Error(codes.PermissionDenied, "user is blocked"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated with the expected phrase",
|
||||
// Wrong status code — must still return false.
|
||||
err: status.Error(codes.Unauthenticated, "peer login has expired, please log in once more"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exact server message",
|
||||
err: status.Error(codes.PermissionDenied, "peer login has expired, please log in once more"),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "phrase as substring",
|
||||
// Future-proofing: if mgm reworords but keeps the phrase,
|
||||
// the friendly fallback must still kick in.
|
||||
err: status.Error(codes.PermissionDenied, "session refused: peer login has expired (account=foo)"),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := isPeerLoginExpired(tc.err); got != tc.want {
|
||||
t.Fatalf("isPeerLoginExpired(%v) = %v, want %v", tc.err, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrSetupKeyOnSSOExpiredPeer(t *testing.T) {
|
||||
// Sentinel must surface as PermissionDenied so the upstream
|
||||
// isPermissionDenied / isAuthError checks classify it correctly
|
||||
// (short-circuit retry backoff, set StatusNeedsLogin).
|
||||
if !isPermissionDenied(errSetupKeyOnSSOExpiredPeer) {
|
||||
t.Fatalf("errSetupKeyOnSSOExpiredPeer must be a PermissionDenied gRPC error")
|
||||
}
|
||||
|
||||
// Message must actually mention SSO and `netbird up` so it is
|
||||
// actionable for the end user. Loose substring checks keep the
|
||||
// test resilient to copy edits.
|
||||
s, _ := status.FromError(errSetupKeyOnSSOExpiredPeer)
|
||||
msg := strings.ToLower(s.Message())
|
||||
for _, want := range []string{"sso", "netbird up"} {
|
||||
if !strings.Contains(msg, want) {
|
||||
t.Errorf("sentinel message should contain %q, got %q", want, s.Message())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PendingFlow stores an in-progress OAuth flow between the RPC that
|
||||
// initiates it (returns the verification URI to the UI) and the RPC
|
||||
// that waits for the user to complete it. The flow handle, the
|
||||
// device-code info, and the absolute expiry are kept together so the
|
||||
// waiting RPC can validate the device code and reuse the same flow.
|
||||
//
|
||||
// PendingFlow is safe for concurrent use; callers must not access the
|
||||
// stored fields directly.
|
||||
type PendingFlow struct {
|
||||
mu sync.Mutex
|
||||
flow OAuthFlow
|
||||
info AuthFlowInfo
|
||||
expiresAt time.Time
|
||||
waitCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewPendingFlow returns an empty PendingFlow ready to be populated by Set.
|
||||
func NewPendingFlow() *PendingFlow {
|
||||
return &PendingFlow{}
|
||||
}
|
||||
|
||||
// Set stores the flow and its authorization info, computing the absolute
|
||||
// expiry from info.ExpiresIn (seconds, as returned by the IdP).
|
||||
func (p *PendingFlow) Set(flow OAuthFlow, info AuthFlowInfo) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.flow = flow
|
||||
p.info = info
|
||||
p.expiresAt = time.Now().Add(time.Duration(info.ExpiresIn) * time.Second)
|
||||
}
|
||||
|
||||
// Get returns the stored flow, info, and whether a flow is currently
|
||||
// pending. Returns (nil, zero, false) after Clear or before Set.
|
||||
func (p *PendingFlow) Get() (OAuthFlow, AuthFlowInfo, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.flow == nil {
|
||||
return nil, AuthFlowInfo{}, false
|
||||
}
|
||||
return p.flow, p.info, true
|
||||
}
|
||||
|
||||
// ExpiresAt returns the absolute expiry of the pending flow. Returns
|
||||
// the zero time when no flow is pending.
|
||||
func (p *PendingFlow) ExpiresAt() time.Time {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.expiresAt
|
||||
}
|
||||
|
||||
// SetWaitCancel records the cancel function for the goroutine currently
|
||||
// blocked in WaitToken so a new RequestAuth can preempt it.
|
||||
func (p *PendingFlow) SetWaitCancel(cancel context.CancelFunc) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.waitCancel = cancel
|
||||
}
|
||||
|
||||
// CancelWait invokes and clears the stored wait-cancel, if any. Safe to
|
||||
// call when no wait is in progress.
|
||||
func (p *PendingFlow) CancelWait() {
|
||||
p.mu.Lock()
|
||||
cancel := p.waitCancel
|
||||
p.waitCancel = nil
|
||||
p.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Clear resets the pending flow to empty. Any stored wait-cancel is
|
||||
// dropped without being invoked — call CancelWait first if the waiting
|
||||
// goroutine must be stopped.
|
||||
func (p *PendingFlow) Clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.flow = nil
|
||||
p.info = AuthFlowInfo{}
|
||||
p.expiresAt = time.Time{}
|
||||
p.waitCancel = nil
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// internal event kinds are no longer exposed: the watcher drives the Sink
|
||||
// directly (NotifyStateChange on deadline change/clear, PublishEvent at
|
||||
// each warning lead). Tests use a mock Sink to observe what the watcher
|
||||
// emits.
|
||||
|
||||
// Metadata keys attached by the daemon to session-warning SystemEvents.
|
||||
// The UI tray reads these to build a locale-aware notification without
|
||||
// relying on the daemon's locale-less UserMessage string, and to
|
||||
// disambiguate the T-WarningLead notification from the T-FinalWarningLead
|
||||
// fallback that auto-opens the SessionAboutToExpire dialog.
|
||||
const (
|
||||
// MetaSessionWarning is set to "true" on both warning events (T-10 and
|
||||
// T-2) so the UI can detect a session-warning SystemEvent without
|
||||
// matching on the message text. Use MetaSessionFinal to distinguish
|
||||
// the two.
|
||||
MetaSessionWarning = "session_warning"
|
||||
// MetaSessionFinal is set to "true" on the T-FinalWarningLead event
|
||||
// only. Consumers that need to auto-open the SessionAboutToExpire
|
||||
// dialog gate on this; T-WarningLead events leave the field unset.
|
||||
MetaSessionFinal = "session_final_warning"
|
||||
// MetaSessionExpiresAt carries the absolute UTC deadline encoded with
|
||||
// FormatExpiresAt; consumers must decode with ParseExpiresAt so a
|
||||
// future format change stays a single edit.
|
||||
MetaSessionExpiresAt = "session_expires_at"
|
||||
// MetaSessionLeadMinutes carries the lead in whole minutes (WarningLead
|
||||
// for the T-10 event, FinalWarningLead for the T-2 event) so the UI
|
||||
// can show "expires in ~N minutes" without hardcoding either constant.
|
||||
MetaSessionLeadMinutes = "lead_minutes"
|
||||
)
|
||||
|
||||
// expiresAtLayout is the wire format used for MetaSessionExpiresAt.
|
||||
// Producer and consumers both go through FormatExpiresAt/ParseExpiresAt
|
||||
// so this layout stays a single source of truth.
|
||||
const expiresAtLayout = time.RFC3339
|
||||
|
||||
// FormatExpiresAt encodes a deadline for MetaSessionExpiresAt. Always
|
||||
// emits UTC so a consumer in another timezone reads the same wall-clock
|
||||
// deadline.
|
||||
func FormatExpiresAt(t time.Time) string {
|
||||
return t.UTC().Format(expiresAtLayout)
|
||||
}
|
||||
|
||||
// ParseExpiresAt decodes the MetaSessionExpiresAt value back to a UTC
|
||||
// time. Returns an error when the field is empty or malformed; the
|
||||
// caller decides whether to fall back (zero value) or propagate.
|
||||
func ParseExpiresAt(s string) (time.Time, error) {
|
||||
t, err := time.Parse(expiresAtLayout, s)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return t.UTC(), nil
|
||||
}
|
||||
|
||||
// FormatLeadMinutes encodes a lead duration for MetaSessionLeadMinutes
|
||||
// as the integer count of whole minutes. Sub-minute residuals are
|
||||
// truncated — the field is informational ("expires in ~N minutes") and
|
||||
// fractional minutes don't change what the UI displays.
|
||||
func FormatLeadMinutes(d time.Duration) string {
|
||||
return strconv.Itoa(int(d / time.Minute))
|
||||
}
|
||||
|
||||
// ParseLeadMinutes decodes a MetaSessionLeadMinutes value. Returns 0
|
||||
// and the parse error for malformed input; consumers that prefer a
|
||||
// silent fallback can simply ignore the error.
|
||||
func ParseLeadMinutes(s string) (int, error) {
|
||||
return strconv.Atoi(s)
|
||||
}
|
||||
@@ -1,387 +0,0 @@
|
||||
// Package sessionwatch tracks the SSO session expiry deadline that the
|
||||
// management server publishes via LoginResponse / SyncResponse and fires
|
||||
// two warning events at fixed lead times before expiry: an interactive
|
||||
// T-WarningLead notification and a dismiss-gated T-FinalWarningLead
|
||||
// fallback dialog.
|
||||
//
|
||||
// The watcher is idempotent: Update may be called as often as the network
|
||||
// map snapshots arrive. Repeating the same deadline is a no-op; a new
|
||||
// deadline reschedules the timers and arms a fresh warning cycle.
|
||||
//
|
||||
// Warning firing is edge-detected. Each unique deadline value fires each
|
||||
// warning callback at most once.
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
// Skew tolerates a small clock difference between the management
|
||||
// server and this peer before treating a deadline as "in the past".
|
||||
// Slightly above typical NTP drift; tight enough that the UI doesn't
|
||||
// paint a stale expiry as if it were valid.
|
||||
Skew = 30 * time.Second
|
||||
|
||||
// maxDeadlineHorizon caps how far in the future an accepted deadline
|
||||
// can sit. A timestamp beyond this is almost certainly a protocol
|
||||
// glitch, and silently arming a 100-year timer would hide the bug.
|
||||
maxDeadlineHorizon = 10 * 365 * 24 * time.Hour
|
||||
|
||||
// WarningLead is how far before expiry the first (interactive)
|
||||
// warning fires. Drives the T-10 OS notification with
|
||||
// Extend/Dismiss actions.
|
||||
WarningLead = 10 * time.Minute
|
||||
|
||||
// FinalWarningLead is how far before expiry the fallback final
|
||||
// warning fires. Drives the auto-opened SessionAboutToExpire dialog,
|
||||
// but only when the user has not dismissed the T-WarningLead warning
|
||||
// for the same deadline. Must be strictly less than WarningLead.
|
||||
FinalWarningLead = 2 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDeadlineBeforeEpoch is returned by Update when the supplied
|
||||
// deadline pre-dates 1970-01-01.
|
||||
ErrDeadlineBeforeEpoch = errors.New("session deadline before unix epoch")
|
||||
|
||||
// ErrDeadlineTooFarFuture is returned by Update when the supplied
|
||||
// deadline is more than maxDeadlineHorizon in the future.
|
||||
ErrDeadlineTooFarFuture = errors.New("session deadline too far in the future")
|
||||
|
||||
// ErrDeadlineInPast is returned by Update when the supplied deadline
|
||||
// is more than Skew in the past.
|
||||
ErrDeadlineInPast = errors.New("session deadline in the past")
|
||||
)
|
||||
|
||||
// StatusRecorder is the side-effect surface the watcher drives on every
|
||||
// state transition. Production wires this to peer.Status (SetSessionExpiresAt
|
||||
// for deadline change/clear, PublishEvent for the two warnings); tests pass
|
||||
// a fake recorder so the same surface is observable without an engine.
|
||||
//
|
||||
// The watcher is the single owner of the deadline propagated to the
|
||||
// recorder: every set, clear, sanity-check rejection and Close routes the
|
||||
// value through SetSessionExpiresAt, so the SubscribeStatus snapshot the UI
|
||||
// reads can never drift from the watcher's timer state. (SetSessionExpiresAt
|
||||
// fans out its own state-change notification, so no separate notify is
|
||||
// needed.) The recorder is server-scoped and outlives this engine-scoped
|
||||
// watcher — without the Close-time clear a teardown (Down, or the Down+Up of
|
||||
// a profile switch) would leave the next session showing the previous one's
|
||||
// stale "expires in" value.
|
||||
//
|
||||
// PublishEvent's signature mirrors peer.Status.PublishEvent: the watcher
|
||||
// composes the metadata internally so the wire format (MetaSession*) is
|
||||
// owned by sessionwatch, not the caller.
|
||||
type StatusRecorder interface {
|
||||
SetSessionExpiresAt(deadline time.Time)
|
||||
PublishEvent(
|
||||
severity cProto.SystemEvent_Severity,
|
||||
category cProto.SystemEvent_Category,
|
||||
message string,
|
||||
userMessage string,
|
||||
metadata map[string]string,
|
||||
)
|
||||
}
|
||||
|
||||
// Watcher observes the latest session deadline and fires two warnings
|
||||
// before it expires: the interactive T-WarningLead notification, and the
|
||||
// fallback T-FinalWarningLead dialog (suppressed when the user dismissed
|
||||
// the first one for the same deadline). Safe for concurrent use.
|
||||
type Watcher struct {
|
||||
lead time.Duration
|
||||
finalLead time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
current time.Time
|
||||
timer *time.Timer
|
||||
finalTimer *time.Timer
|
||||
firedAt time.Time // deadline value the T-WarningLead callback last fired against
|
||||
finalFiredAt time.Time // deadline value the T-FinalWarningLead callback last fired against
|
||||
dismissedAt time.Time // deadline value the user dismissed via Dismiss(); gates fireFinal
|
||||
closed bool
|
||||
recorder StatusRecorder
|
||||
}
|
||||
|
||||
// New returns a watcher with the package defaults WarningLead and
|
||||
// FinalWarningLead. Pass nil for recorder to silence side effects (handy
|
||||
// in unit tests that exercise sanity checks without observing the publish
|
||||
// path).
|
||||
func New(recorder StatusRecorder) *Watcher {
|
||||
return NewWithLeads(WarningLead, FinalWarningLead, recorder)
|
||||
}
|
||||
|
||||
// NewWithLeads returns a watcher with custom lead times. Useful for tests.
|
||||
// final must be strictly less than lead; otherwise both timers fire in the
|
||||
// wrong order or simultaneously and the UI flow breaks. A zero final lead
|
||||
// disables the final-warning timer entirely (see armTimerLocked) so a
|
||||
// millisecond-scale deadline doesn't flush both timers in one tick.
|
||||
func NewWithLeads(lead, final time.Duration, recorder StatusRecorder) *Watcher {
|
||||
return &Watcher{
|
||||
lead: lead,
|
||||
finalLead: final,
|
||||
recorder: recorder,
|
||||
}
|
||||
}
|
||||
|
||||
// Update sets the latest deadline. Pass the zero time to clear (e.g. when
|
||||
// a Sync push from the server omits the field because login expiration
|
||||
// was disabled).
|
||||
//
|
||||
// Same-value updates are no-ops. A different non-zero value cancels any
|
||||
// pending timer, resets the "already fired" guard, and arms a new one.
|
||||
//
|
||||
// Returns one of the sentinel Err* values when the deadline fails the
|
||||
// sanity checks (pre-epoch, far future, or in the past beyond Skew).
|
||||
// In every error case the watcher first clears its state so it stays
|
||||
// consistent with what the caller will push into its other sinks (e.g.
|
||||
// applySessionDeadline forces a zero deadline into the status recorder
|
||||
// after a non-nil error).
|
||||
func (w *Watcher) Update(deadline time.Time) error {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
if deadline.IsZero() {
|
||||
w.clearLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
switch {
|
||||
case deadline.Before(time.Unix(0, 0)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v", ErrDeadlineBeforeEpoch, deadline)
|
||||
case deadline.After(now.Add(maxDeadlineHorizon)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v", ErrDeadlineTooFarFuture, deadline)
|
||||
case deadline.Before(now.Add(-Skew)):
|
||||
w.clearLocked()
|
||||
return fmt.Errorf("%w: %v (now=%v)", ErrDeadlineInPast, deadline, now)
|
||||
}
|
||||
|
||||
if deadline.Equal(w.current) {
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
w.stopTimerLocked()
|
||||
w.current = deadline
|
||||
// Reset every per-deadline guard so a refreshed deadline arms a fresh
|
||||
// warning cycle: both edge triggers and the user Dismiss decision
|
||||
// (the user agreed to the old deadline expiring; a new deadline
|
||||
// restarts the contract).
|
||||
w.firedAt = time.Time{}
|
||||
w.finalFiredAt = time.Time{}
|
||||
w.dismissedAt = time.Time{}
|
||||
|
||||
w.armTimerLocked(deadline)
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder != nil {
|
||||
recorder.SetSessionExpiresAt(deadline)
|
||||
}
|
||||
log.Infof("auth session deadline set to: %s (in %s)", deadline.Format(time.RFC3339), time.Until(deadline).Round(time.Second))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deadline returns the most recently observed deadline. Zero when no
|
||||
// deadline is currently tracked.
|
||||
func (w *Watcher) Deadline() time.Time {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.current
|
||||
}
|
||||
|
||||
// Dismiss records the user's "Dismiss" action against the current deadline
|
||||
// and suppresses the upcoming final-warning callback for that deadline.
|
||||
// Idempotent: repeated calls are no-ops. A subsequent Update with a fresh
|
||||
// deadline resets the dismissal so the final-warning cycle re-arms.
|
||||
//
|
||||
// No-op when the watcher holds no deadline or has been closed.
|
||||
func (w *Watcher) Dismiss() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.closed || w.current.IsZero() {
|
||||
return
|
||||
}
|
||||
if w.dismissedAt.Equal(w.current) {
|
||||
return
|
||||
}
|
||||
w.dismissedAt = w.current
|
||||
// Cancel the armed final-warning timer eagerly. fireFinal would also
|
||||
// gate on dismissedAt, but stopping the timer avoids a wakeup with
|
||||
// nothing to do and makes the intent visible.
|
||||
if w.finalTimer != nil {
|
||||
w.finalTimer.Stop()
|
||||
w.finalTimer = nil
|
||||
}
|
||||
log.Infof("auth session final-warning dismissed for deadline %s", w.current.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// Close stops any pending timer and drops the deadline on the status
|
||||
// recorder. Update calls after Close are ignored. Clearing the recorder
|
||||
// here is what keeps a teardown (Down, or the Down+Up of a profile switch)
|
||||
// from leaving the next session showing this one's stale "expires in"
|
||||
// value — the recorder is server-scoped and outlives this engine-scoped
|
||||
// watcher, so nothing else drops the anchor on teardown.
|
||||
func (w *Watcher) Close() {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.closed = true
|
||||
w.stopTimerLocked()
|
||||
hadDeadline := !w.current.IsZero()
|
||||
w.current = time.Time{}
|
||||
w.firedAt = time.Time{}
|
||||
w.finalFiredAt = time.Time{}
|
||||
w.dismissedAt = time.Time{}
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder != nil && hadDeadline {
|
||||
recorder.SetSessionExpiresAt(time.Time{})
|
||||
}
|
||||
}
|
||||
|
||||
// clearLocked drops the tracked deadline and notifies the recorder so
|
||||
// downstream consumers (SubscribeStatus stream, UI) drop their anchor.
|
||||
// The caller must hold w.mu; this helper releases it before invoking
|
||||
// the recorder.
|
||||
func (w *Watcher) clearLocked() {
|
||||
if w.current.IsZero() {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.stopTimerLocked()
|
||||
w.current = time.Time{}
|
||||
w.firedAt = time.Time{}
|
||||
w.finalFiredAt = time.Time{}
|
||||
w.dismissedAt = time.Time{}
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder != nil {
|
||||
recorder.SetSessionExpiresAt(time.Time{})
|
||||
}
|
||||
log.Infof("auth session deadline cleared")
|
||||
}
|
||||
|
||||
func (w *Watcher) stopTimerLocked() {
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
w.timer = nil
|
||||
}
|
||||
if w.finalTimer != nil {
|
||||
w.finalTimer.Stop()
|
||||
w.finalTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) armTimerLocked(deadline time.Time) {
|
||||
w.timer = armOneShotLocked(deadline.Add(-w.lead), func() { w.fire(deadline) })
|
||||
// finalLead <= 0 disables the final-warning timer entirely. Used by
|
||||
// tests that predate the final-warning fallback so a millisecond-scale
|
||||
// deadline does not flush both timers at once.
|
||||
if w.finalLead > 0 {
|
||||
w.finalTimer = armOneShotLocked(deadline.Add(-w.finalLead), func() { w.fireFinal(deadline) })
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) fire(armedFor time.Time) {
|
||||
w.mu.Lock()
|
||||
if w.closed || !w.current.Equal(armedFor) {
|
||||
// Deadline moved while we were waiting (e.g. a successful extend).
|
||||
// The reschedule path armed a fresh timer; this one is stale.
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if !w.firedAt.IsZero() && w.firedAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.firedAt = armedFor
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder == nil {
|
||||
return
|
||||
}
|
||||
log.Infof("auth session expiry soon warning fired")
|
||||
publishWarning(recorder, armedFor, false)
|
||||
}
|
||||
|
||||
// fireFinal mirrors fire for the T-FinalWarningLead timer with an extra
|
||||
// dismiss-gate: if the user dismissed the T-WarningLead notification for
|
||||
// this deadline, the final warning is suppressed entirely.
|
||||
func (w *Watcher) fireFinal(armedFor time.Time) {
|
||||
w.mu.Lock()
|
||||
if w.closed || !w.current.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if !w.finalFiredAt.IsZero() && w.finalFiredAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if w.dismissedAt.Equal(armedFor) {
|
||||
w.mu.Unlock()
|
||||
log.Infof("auth session final-warning skipped (dismissed by user)")
|
||||
return
|
||||
}
|
||||
w.finalFiredAt = armedFor
|
||||
recorder := w.recorder
|
||||
w.mu.Unlock()
|
||||
if recorder == nil {
|
||||
return
|
||||
}
|
||||
log.Infof("auth session final-warning fired")
|
||||
publishWarning(recorder, armedFor, true)
|
||||
}
|
||||
|
||||
// armOneShotLocked schedules cb at fireAt. When fireAt is already in the
|
||||
// past it dispatches on the next scheduler tick so a state-change recorder
|
||||
// notification (invoked after w.mu is released) lands first. Caller must
|
||||
// hold w.mu.
|
||||
func armOneShotLocked(fireAt time.Time, cb func()) *time.Timer {
|
||||
delay := time.Until(fireAt)
|
||||
if delay <= 0 {
|
||||
return time.AfterFunc(0, cb)
|
||||
}
|
||||
return time.AfterFunc(delay, cb)
|
||||
}
|
||||
|
||||
// publishWarning composes the SystemEvent for a watcher-fired warning and
|
||||
// pushes it through the recorder. Severity is CRITICAL on both — bypassing
|
||||
// the user's Notifications toggle is deliberate: missing the warning
|
||||
// window forces the post-mortem SessionExpired flow (tunnel torn down,
|
||||
// lock icon, manual re-login), which is the UX we are trying to avoid.
|
||||
func publishWarning(recorder StatusRecorder, deadline time.Time, final bool) {
|
||||
lead := WarningLead
|
||||
message := "session expiry warning"
|
||||
meta := map[string]string{
|
||||
MetaSessionWarning: "true",
|
||||
MetaSessionExpiresAt: FormatExpiresAt(deadline),
|
||||
}
|
||||
if final {
|
||||
lead = FinalWarningLead
|
||||
message = "session expiry final warning"
|
||||
meta[MetaSessionFinal] = "true"
|
||||
}
|
||||
meta[MetaSessionLeadMinutes] = FormatLeadMinutes(lead)
|
||||
|
||||
recorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_AUTHENTICATION,
|
||||
message,
|
||||
"",
|
||||
meta,
|
||||
)
|
||||
}
|
||||
@@ -1,519 +0,0 @@
|
||||
package sessionwatch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// fakeRecorder satisfies StatusRecorder and records every call so tests
|
||||
// can observe what the watcher emits. SetSessionExpiresAt and PublishEvent
|
||||
// land in the same ordered events slice (with the Kind distinguishing
|
||||
// them) so tests that care about ordering still work. lastDeadline holds
|
||||
// the most recent value passed to SetSessionExpiresAt so tests can assert
|
||||
// the recorder ended up cleared/set as expected.
|
||||
type fakeRecorder struct {
|
||||
mu sync.Mutex
|
||||
events []event
|
||||
lastDeadline time.Time
|
||||
}
|
||||
|
||||
type eventKind int
|
||||
|
||||
const (
|
||||
stateChange eventKind = iota
|
||||
publish
|
||||
)
|
||||
|
||||
type event struct {
|
||||
kind eventKind
|
||||
// Set only for publish events.
|
||||
severity cProto.SystemEvent_Severity
|
||||
category cProto.SystemEvent_Category
|
||||
message string
|
||||
meta map[string]string
|
||||
}
|
||||
|
||||
// SetSessionExpiresAt mirrors peer.Status: a same-value write is a no-op,
|
||||
// a real change records the new value and fans out a state-change (the
|
||||
// production recorder calls notifyStateChange internally). The baseline
|
||||
// is the zero time, so an initial clear before any deadline is set emits
|
||||
// nothing — matching the real recorder.
|
||||
func (r *fakeRecorder) SetSessionExpiresAt(deadline time.Time) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.lastDeadline.Equal(deadline) {
|
||||
return
|
||||
}
|
||||
r.lastDeadline = deadline
|
||||
r.events = append(r.events, event{kind: stateChange})
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) deadline() time.Time {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.lastDeadline
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) PublishEvent(
|
||||
severity cProto.SystemEvent_Severity,
|
||||
category cProto.SystemEvent_Category,
|
||||
message string,
|
||||
_ string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.events = append(r.events, event{
|
||||
kind: publish,
|
||||
severity: severity,
|
||||
category: category,
|
||||
message: message,
|
||||
meta: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *fakeRecorder) snapshot() []event {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([]event, len(r.events))
|
||||
copy(out, r.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e event) isFinalWarning() bool {
|
||||
return e.kind == publish && e.meta[MetaSessionFinal] == "true"
|
||||
}
|
||||
|
||||
func (e event) isWarning() bool {
|
||||
return e.kind == publish && e.meta[MetaSessionWarning] == "true" && e.meta[MetaSessionFinal] != "true"
|
||||
}
|
||||
|
||||
func countWhere(events []event, pred func(event) bool) int {
|
||||
n := 0
|
||||
for _, e := range events {
|
||||
if pred(e) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func waitForEvents(t *testing.T, r *fakeRecorder, want int) []event {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if got := r.snapshot(); len(got) >= want {
|
||||
return got
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
got := r.snapshot()
|
||||
t.Fatalf("timed out waiting for %d events, got %d: %+v", want, len(got), got)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newWatcher builds a watcher with the final timer disabled (finalLead=0),
|
||||
// matching the lead-only behaviour the pre-final-warning tests assume.
|
||||
func newWatcher(lead time.Duration, r *fakeRecorder) *Watcher {
|
||||
return NewWithLeads(lead, 0, r)
|
||||
}
|
||||
|
||||
func TestUpdateZeroBeforeAnythingIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
_ = w.Update(time.Time{})
|
||||
|
||||
if got := r.snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected no events on initial zero, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNonZeroFiresStateChange(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(time.Hour)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 1)
|
||||
if events[0].kind != stateChange {
|
||||
t.Fatalf("expected stateChange, got %+v", events[0])
|
||||
}
|
||||
if !w.Deadline().Equal(d) {
|
||||
t.Fatalf("deadline mismatch: %v vs %v", w.Deadline(), d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSameDeadlineIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(time.Hour)
|
||||
_ = w.Update(d)
|
||||
_ = w.Update(d)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 1)
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected exactly 1 event for repeated same deadline, got %d: %+v", len(events), events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarningFiresOnceWithinLeadWindow(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 50 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
// Deadline 80ms out — warning should fire after ~30ms.
|
||||
d := time.Now().Add(80 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[0].kind != stateChange {
|
||||
t.Fatalf("event[0] should be stateChange, got %+v", events[0])
|
||||
}
|
||||
if !events[1].isWarning() {
|
||||
t.Fatalf("event[1] should be a warning publish, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarningFiresImmediatelyWhenAlreadyInsideWindow(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r) // lead > delta => fire immediately
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(10 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if !events[1].isWarning() {
|
||||
t.Fatalf("expected immediate warning publish, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDeadlineCancelsPriorTimer(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 50 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(80 * time.Millisecond) // would fire warning ~30ms in
|
||||
_ = w.Update(first)
|
||||
|
||||
// Replace with a far-future deadline before the warning fires.
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
second := time.Now().Add(time.Hour)
|
||||
_ = w.Update(second)
|
||||
|
||||
// Wait past when first's warning would have fired.
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
|
||||
if n := countWhere(r.snapshot(), event.isWarning); n != 0 {
|
||||
t.Fatalf("warning fired for cancelled deadline: %+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAfterFireArmsNewWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
lead := 30 * time.Millisecond
|
||||
w := newWatcher(lead, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(50 * time.Millisecond)
|
||||
_ = w.Update(first)
|
||||
|
||||
// Wait for stateChange + warning of the first cycle.
|
||||
waitForEvents(t, r, 2)
|
||||
|
||||
// Simulate a successful extend: brand new deadline.
|
||||
second := time.Now().Add(60 * time.Millisecond)
|
||||
_ = w.Update(second)
|
||||
|
||||
// 4 events total: stateChange, warning (first), stateChange, warning (second).
|
||||
events := waitForEvents(t, r, 4)
|
||||
if events[2].kind != stateChange {
|
||||
t.Fatalf("event[2] should be stateChange for the new deadline, got %+v", events[2])
|
||||
}
|
||||
if !events[3].isWarning() {
|
||||
t.Fatalf("event[3] should be a warning publish for the new deadline, got %+v", events[3])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateZeroAfterNonZeroClearsState(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(2 * time.Hour)
|
||||
_ = w.Update(d)
|
||||
waitForEvents(t, r, 1)
|
||||
|
||||
_ = w.Update(time.Time{})
|
||||
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[1].kind != stateChange {
|
||||
t.Fatalf("expected stateChange on clear, got %+v", events[1])
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("Deadline should be zero after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRejectsBeforeEpoch(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
|
||||
err := w.Update(time.Unix(-100, 0))
|
||||
if !errors.Is(err, ErrDeadlineBeforeEpoch) {
|
||||
t.Fatalf("want ErrDeadlineBeforeEpoch, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("rejected pre-epoch update must clear deadline; got %v", w.Deadline())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRejectsTooFarFuture(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
|
||||
err := w.Update(time.Now().Add(50 * 365 * 24 * time.Hour))
|
||||
if !errors.Is(err, ErrDeadlineTooFarFuture) {
|
||||
t.Fatalf("want ErrDeadlineTooFarFuture, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("rejected far-future update must clear deadline; got %v", w.Deadline())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInPastClearsDeadline(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
good := time.Now().Add(time.Hour)
|
||||
if err := w.Update(good); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
// Drain the stateChange from the seed.
|
||||
waitForEvents(t, r, 1)
|
||||
|
||||
err := w.Update(time.Now().Add(-1 * time.Hour))
|
||||
if !errors.Is(err, ErrDeadlineInPast) {
|
||||
t.Fatalf("want ErrDeadlineInPast, got %v", err)
|
||||
}
|
||||
if !w.Deadline().IsZero() {
|
||||
t.Fatalf("in-past update must clear the deadline, got %v", w.Deadline())
|
||||
}
|
||||
events := waitForEvents(t, r, 2)
|
||||
if events[1].kind != stateChange {
|
||||
t.Fatalf("expected stateChange on clear, got %+v", events[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWithinSkewAccepted(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
// 5 seconds in the past is within the 30s Skew tolerance — accept it.
|
||||
d := time.Now().Add(-5 * time.Second)
|
||||
if err := w.Update(d); err != nil {
|
||||
t.Fatalf("within-skew Update should succeed, got %v", err)
|
||||
}
|
||||
if !w.Deadline().Equal(d) {
|
||||
t.Fatalf("expected deadline to be applied, got %v want %v", w.Deadline(), d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseSilencesUpdates(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(50*time.Millisecond, r)
|
||||
w.Close()
|
||||
|
||||
_ = w.Update(time.Now().Add(time.Hour))
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if got := r.snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected no events after Close, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloseClearsRecorderDeadline pins the profile-switch fix: a watcher
|
||||
// holding a live deadline must zero the recorder on Close so the next
|
||||
// engine's watcher (and the UI reading the shared server-scoped recorder)
|
||||
// doesn't start out showing the previous session's stale "expires in".
|
||||
func TestCloseClearsRecorderDeadline(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r)
|
||||
|
||||
d := time.Now().Add(2 * time.Hour)
|
||||
if err := w.Update(d); err != nil {
|
||||
t.Fatalf("seed Update: %v", err)
|
||||
}
|
||||
if got := r.deadline(); !got.Equal(d) {
|
||||
t.Fatalf("recorder deadline after Update = %v, want %v", got, d)
|
||||
}
|
||||
|
||||
w.Close()
|
||||
|
||||
if got := r.deadline(); !got.IsZero() {
|
||||
t.Fatalf("recorder deadline after Close = %v, want zero", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloseWithoutDeadlineLeavesRecorderUntouched guards the symmetric
|
||||
// case: closing a watcher that never held a deadline must not emit a
|
||||
// redundant clear (the recorder may legitimately hold a value written by
|
||||
// some other path; the watcher only owns what it set).
|
||||
func TestCloseWithoutDeadlineLeavesRecorderUntouched(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := newWatcher(time.Hour, r)
|
||||
|
||||
w.Close()
|
||||
|
||||
if got := r.snapshot(); len(got) != 0 {
|
||||
t.Fatalf("expected no events from Close on an empty watcher, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalWarningFiresAfterRegularWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
// Warning fires at deadline-80ms, final at deadline-30ms.
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Expect stateChange + warning + final-warning.
|
||||
events := waitForEvents(t, r, 3)
|
||||
|
||||
if countWhere(events, func(e event) bool { return e.kind == stateChange }) != 1 {
|
||||
t.Fatalf("expected exactly 1 stateChange, got %+v", events)
|
||||
}
|
||||
if countWhere(events, event.isWarning) != 1 {
|
||||
t.Fatalf("expected exactly 1 warning publish, got %+v", events)
|
||||
}
|
||||
if countWhere(events, event.isFinalWarning) != 1 {
|
||||
t.Fatalf("expected exactly 1 final-warning publish, got %+v", events)
|
||||
}
|
||||
|
||||
// Warning must precede final (same deadline, longer lead fires first).
|
||||
var wIdx, fIdx int
|
||||
for i, e := range events {
|
||||
switch {
|
||||
case e.isWarning():
|
||||
wIdx = i
|
||||
case e.isFinalWarning():
|
||||
fIdx = i
|
||||
}
|
||||
}
|
||||
if wIdx > fIdx {
|
||||
t.Fatalf("warning must publish before final-warning, got order %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissSuppressesFinalWarning(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Wait for the warning publish so we know we're inside the warning
|
||||
// window, then dismiss before the final timer would fire.
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isWarning) >= 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
if countWhere(r.snapshot(), event.isWarning) < 1 {
|
||||
t.Fatalf("warning did not publish in time, events=%+v", r.snapshot())
|
||||
}
|
||||
|
||||
w.Dismiss()
|
||||
|
||||
// Now wait past when the final would have fired.
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
if n := countWhere(r.snapshot(), event.isFinalWarning); n != 0 {
|
||||
t.Fatalf("final-warning published after Dismiss(), events=%+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissResetByNewDeadline(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
first := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(first)
|
||||
|
||||
// Dismiss against the first deadline.
|
||||
w.Dismiss()
|
||||
|
||||
// Replace with a fresh deadline before the first's timers complete.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
second := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(second)
|
||||
|
||||
// The second cycle must publish a final-warning (the dismiss state
|
||||
// did not carry over).
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) >= 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) < 1 {
|
||||
t.Fatalf("final-warning did not publish on fresh deadline after Dismiss reset, events=%+v", r.snapshot())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDismissBeforeUpdateIsNoop(t *testing.T) {
|
||||
r := &fakeRecorder{}
|
||||
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
|
||||
defer w.Close()
|
||||
|
||||
// No deadline tracked yet; Dismiss must be a no-op (no panic, no state).
|
||||
w.Dismiss()
|
||||
|
||||
d := time.Now().Add(100 * time.Millisecond)
|
||||
_ = w.Update(d)
|
||||
|
||||
// Final warning should still publish — Dismiss only acts on the current
|
||||
// deadline, and there was none at the time of the call.
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if countWhere(r.snapshot(), event.isFinalWarning) >= 1 {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("final-warning did not publish after no-op pre-Update Dismiss, events=%+v", r.snapshot())
|
||||
}
|
||||
@@ -256,15 +256,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
// On daemon shutdown / Down() the parent context is cancelled
|
||||
// and the dial fails with "context canceled". Wrapping that
|
||||
// into state would leave the snapshot stuck at Connecting+err
|
||||
// until the backoff loop wakes up — instead let the operation
|
||||
// return cleanly so the deferred state.Set(StatusIdle) takes
|
||||
// effect on the next iteration.
|
||||
if c.ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||
@@ -393,10 +384,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
// Seed the session-expiry deadline from the LoginResponse. Subsequent
|
||||
// changes flow in through SyncResponse and are applied in handleSync.
|
||||
engine.ApplySessionDeadline(loginResp.GetSessionExpiresAt())
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
@@ -437,11 +424,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
c.statusRecorder.ClientStart()
|
||||
// Wrap the backoff with c.ctx so Down()/actCancel propagates into the
|
||||
// inter-attempt sleep — otherwise a 15s MaxInterval can keep the retry
|
||||
// loop alive long after the caller asked to give up, leaving the
|
||||
// status stream stuck at Connecting.
|
||||
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||
err = backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
|
||||
@@ -900,7 +900,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
|
||||
pf, err := uspfilter.Create(wgIface, nil, false, flowLogger, iface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -3,7 +3,6 @@ package dnsfwd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -160,12 +159,13 @@ func (m *Manager) allowDNSFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
anyV4 := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
dnsRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add udp firewall rule: %w", err)
|
||||
}
|
||||
|
||||
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
||||
tcpRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add tcp firewall rule: %w", err)
|
||||
}
|
||||
@@ -174,8 +174,12 @@ func (m *Manager) allowDNSFirewall() error {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
m.fwRules = dnsRules
|
||||
m.tcpRules = tcpRules
|
||||
if dnsRule != nil {
|
||||
m.fwRules = []firewall.Rule{dnsRule}
|
||||
}
|
||||
if tcpRule != nil {
|
||||
m.tcpRules = []firewall.Rule{tcpRule}
|
||||
}
|
||||
|
||||
m.registerNetstackServices()
|
||||
|
||||
@@ -209,12 +213,12 @@ func (m *Manager) unregisterNetstackServices() {
|
||||
func (m *Manager) dropDNSFirewall() error {
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range m.fwRules {
|
||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||
if err := m.firewall.DeleteFilterRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
for _, rule := range m.tcpRules {
|
||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||
if err := m.firewall.DeleteFilterRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,20 +250,6 @@ type Engine struct {
|
||||
jobExecutorWG sync.WaitGroup
|
||||
|
||||
exposeManager *expose.Manager
|
||||
|
||||
sessionWatcher sessionDeadlineWatcher
|
||||
}
|
||||
|
||||
// sessionDeadlineWatcher is the engine-facing surface of the SSO session
|
||||
// expiry watcher. The concrete implementation (sessionwatch.Watcher) is wired
|
||||
// in via newSessionWatcher, which is build-tagged so the js/wasm build links a
|
||||
// no-op stub instead of pulling the full sessionwatch package (and its timer
|
||||
// machinery) into the binary — the wasm client never runs the engine's
|
||||
// session-warning flow.
|
||||
type sessionDeadlineWatcher interface {
|
||||
Update(deadline time.Time) error
|
||||
Dismiss()
|
||||
Close()
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -307,17 +293,6 @@ func NewEngine(
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
}
|
||||
// sessionWatcher keeps the SubscribeStatus consumers in sync with the
|
||||
// session expiry deadline. Deadline-change ticks come for free via
|
||||
// Status.SetSessionExpiresAt; the watcher exists to push a wake-up at
|
||||
// T-WarningLead and T-FinalWarningLead so the UI repaints the remaining
|
||||
// time / warning state even when nothing else changed, and to publish
|
||||
// two SystemEvents (the warning composition lives in sessionwatch so
|
||||
// the wire format stays owned by one package):
|
||||
// - T-WarningLead → interactive "Extend now / Dismiss" notification
|
||||
// - T-FinalWarningLead → auto-opened SessionAboutToExpire dialog,
|
||||
// suppressed when the user dismissed the earlier warning
|
||||
engine.sessionWatcher = newSessionWatcher(engine.statusRecorder)
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
@@ -358,10 +333,6 @@ func (e *Engine) Stop() error {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
|
||||
if e.sessionWatcher != nil {
|
||||
e.sessionWatcher.Close()
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.SetDownloadOnly()
|
||||
}
|
||||
@@ -669,14 +640,14 @@ func (e *Engine) initFirewall() error {
|
||||
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||
|
||||
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
|
||||
if _, err := e.firewall.AddPeerFiltering(
|
||||
if _, err := e.firewall.AddFilterRule(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
firewallManager.Network{},
|
||||
firewallManager.ProtocolUDP,
|
||||
nil,
|
||||
&port,
|
||||
firewallManager.ActionAccept,
|
||||
"",
|
||||
); err != nil {
|
||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||
return nil
|
||||
@@ -726,7 +697,7 @@ func (e *Engine) blockLanAccess() {
|
||||
if network.Addr().Is6() {
|
||||
source = v6
|
||||
}
|
||||
if _, err := e.firewall.AddRouteFiltering(
|
||||
if _, err := e.firewall.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{source},
|
||||
firewallManager.Network{Prefix: network},
|
||||
@@ -894,8 +865,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
e.ApplySessionDeadline(update.GetSessionExpiresAt())
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
@@ -1173,7 +1142,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
TempDir: e.config.TempDir,
|
||||
ClientMetrics: e.clientMetrics,
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(e.ctx, true)
|
||||
e.RunHealthProbes(true)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2058,20 +2027,7 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
//
|
||||
// ctx scopes the (potentially slow) STUN/TURN probing: a caller that gives up —
|
||||
// e.g. a Status RPC whose client disconnected — cancels its ctx and the probe
|
||||
// returns instead of running to its per-component timeout. The engine's own
|
||||
// lifetime ctx still applies independently, so an engine shutdown aborts the
|
||||
// probe even if the caller's ctx is context.Background().
|
||||
func (e *Engine) RunHealthProbes(ctx context.Context, waitForResult bool) bool {
|
||||
// Tie the caller's ctx to the engine lifetime: either cancelling aborts
|
||||
// the probe below.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
stop := context.AfterFunc(e.ctx, cancel)
|
||||
defer stop()
|
||||
|
||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
signalHealthy := e.signal.IsHealthy()
|
||||
@@ -2094,9 +2050,9 @@ func (e *Engine) RunHealthProbes(ctx context.Context, waitForResult bool) bool {
|
||||
if runtime.GOOS != "js" {
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(ctx, stuns, turns)
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(ctx, stuns, turns)
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
@@ -2413,7 +2369,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
|
||||
var merr *multierror.Error
|
||||
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
proto, err := convertToFirewallProtocol(rule.GetProtocol())
|
||||
proto, err := acl.ConvertToFirewallProtocol(rule.GetProtocol())
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
|
||||
continue
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
// ApplySessionDeadline propagates the absolute SSO session deadline carried on
|
||||
// LoginResponse / SyncResponse to both the watcher (for the edge-triggered
|
||||
// warning) and the status recorder (for the SubscribeStatus / Status RPC
|
||||
// snapshot the UI consumes).
|
||||
//
|
||||
// The wire field is 3-state:
|
||||
// - nil → snapshot carries no info; keep the
|
||||
// previously-anchored deadline (no-op)
|
||||
// - explicit zero (s=0, n=0) → peer is not SSO-registered or expiry is
|
||||
// disabled; clear both sinks
|
||||
// - valid timestamp → new deadline; arm watcher, expose on
|
||||
// status recorder
|
||||
//
|
||||
// Deadline sanity-checks live in sessionwatch.Watcher.Update. Any rejected
|
||||
// value is treated as a clear on both sinks: the alternative — leaving the
|
||||
// previously-known deadline in place — risks the UI confidently displaying
|
||||
// a stale "expires in X" while the server has actually invalidated it.
|
||||
func (e *Engine) ApplySessionDeadline(ts *timestamppb.Timestamp) {
|
||||
if ts == nil {
|
||||
return
|
||||
}
|
||||
var deadline time.Time
|
||||
// Explicit zero (seconds=0 AND nanos=0) is the sentinel for "disabled".
|
||||
// Everything else flows through Watcher.Update, whose sanity-checks
|
||||
// reject out-of-range / pre-epoch / far-future / too-stale values and
|
||||
// clear on rejection.
|
||||
if ts.GetSeconds() != 0 || ts.GetNanos() != 0 {
|
||||
deadline = ts.AsTime().UTC()
|
||||
}
|
||||
if e.sessionWatcher == nil {
|
||||
return
|
||||
}
|
||||
// Watcher.Update owns the propagation to the status recorder (the
|
||||
// SubscribeStatus / Status snapshot the UI reads): a set writes the
|
||||
// deadline, a clear or a sanity-check rejection writes the zero value.
|
||||
// Keeping a single writer is what stops the recorder from drifting out
|
||||
// of sync with the warning timers.
|
||||
if err := e.sessionWatcher.Update(deadline); err != nil {
|
||||
log.Errorf("auth session deadline rejected: %v, clearing", err)
|
||||
}
|
||||
}
|
||||
|
||||
// DismissSessionWarning records the user's "Dismiss" click on the
|
||||
// T-WarningLead interactive notification and suppresses the upcoming
|
||||
// T-FinalWarningLead fallback for the current deadline. No-op when the
|
||||
// watcher is not running or holds no deadline.
|
||||
func (e *Engine) DismissSessionWarning() {
|
||||
if e.sessionWatcher == nil {
|
||||
return
|
||||
}
|
||||
e.sessionWatcher.Dismiss()
|
||||
}
|
||||
|
||||
// ExtendAuthSession asks the management server to refresh the SSO session
|
||||
// expiry deadline using the supplied JWT, then mirrors the new deadline into
|
||||
// the daemon's state. The tunnel is untouched; no resync, no reconnect.
|
||||
//
|
||||
// Returns the new absolute UTC deadline (or zero time when the server
|
||||
// reports the peer is not eligible for extension).
|
||||
func (e *Engine) ExtendAuthSession(ctx context.Context, jwtToken string) (time.Time, error) {
|
||||
if jwtToken == "" {
|
||||
return time.Time{}, errors.New("jwt token is required")
|
||||
}
|
||||
if e.mgmClient == nil {
|
||||
return time.Time{}, errors.New("management client is not initialised")
|
||||
}
|
||||
|
||||
info, err := system.GetInfoWithChecks(ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to collect system info for session extend: %v", err)
|
||||
info = system.GetInfo(ctx)
|
||||
}
|
||||
|
||||
resp, err := e.mgmClient.ExtendAuthSession(info, jwtToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("extend auth session on management: %w", err)
|
||||
}
|
||||
|
||||
e.ApplySessionDeadline(resp.GetSessionExpiresAt())
|
||||
|
||||
if resp.GetSessionExpiresAt().IsValid() {
|
||||
return resp.GetSessionExpiresAt().AsTime().UTC(), nil
|
||||
}
|
||||
return time.Time{}, nil
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
// TestApplySessionDeadline_ThreeState pins down the 3-state semantics of the
|
||||
// wire field carried on LoginResponse / SyncResponse:
|
||||
//
|
||||
// - nil pointer → no info; previously-anchored deadline survives
|
||||
// - explicit zero value → "expiry disabled" sentinel; both sinks cleared
|
||||
// - valid future timestamp → new deadline propagated to both sinks
|
||||
func TestApplySessionDeadline_ThreeState(t *testing.T) {
|
||||
newEngine := func() *Engine {
|
||||
recorder := peer.NewRecorder("")
|
||||
return &Engine{
|
||||
statusRecorder: recorder,
|
||||
sessionWatcher: sessionwatch.New(recorder),
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("valid timestamp sets deadline on both sinks", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
deadline := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
|
||||
e.ApplySessionDeadline(timestamppb.New(deadline))
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(deadline),
|
||||
"status recorder should hold the new deadline")
|
||||
})
|
||||
|
||||
t.Run("nil is a no-op and preserves previous deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
e.ApplySessionDeadline(nil)
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded),
|
||||
"nil snapshot must not disturb the existing deadline")
|
||||
})
|
||||
|
||||
t.Run("explicit zero clears a previously-anchored deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
// Explicit zero Timestamp{} (seconds=0, nanos=0) is the
|
||||
// "expiry disabled / not SSO" sentinel.
|
||||
e.ApplySessionDeadline(×tamppb.Timestamp{})
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().IsZero(),
|
||||
"explicit zero sentinel must clear the deadline")
|
||||
})
|
||||
|
||||
t.Run("invalid timestamp clears the deadline", func(t *testing.T) {
|
||||
e := newEngine()
|
||||
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
|
||||
e.ApplySessionDeadline(timestamppb.New(seeded))
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
|
||||
|
||||
// Out-of-range nanos → IsValid()==false; same-meaning as the
|
||||
// disabled sentinel for downstream sinks.
|
||||
e.ApplySessionDeadline(×tamppb.Timestamp{Seconds: 1, Nanos: -1})
|
||||
|
||||
require.True(t, e.statusRecorder.GetSessionExpiresAt().IsZero(),
|
||||
"invalid timestamp must clear the deadline")
|
||||
})
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
// newSessionWatcher returns the real SSO session expiry watcher for every
|
||||
// non-wasm build. The js/wasm build gets a no-op stub from
|
||||
// engine_sessionwatch_js.go so the sessionwatch package (and its timer
|
||||
// machinery) never links into the wasm binary.
|
||||
func newSessionWatcher(recorder *peer.Status) sessionDeadlineWatcher {
|
||||
return sessionwatch.New(recorder)
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
// noopSessionWatcher is the js/wasm stand-in for sessionwatch.Watcher. The
|
||||
// wasm client never runs the engine's session-warning flow (the interactive
|
||||
// T-WarningLead notification and the T-FinalWarningLead fallback dialog live
|
||||
// in the desktop UI), so linking the full sessionwatch package (timers, event
|
||||
// composition) would only bloat the binary.
|
||||
//
|
||||
// It still mirrors the deadline into the status recorder so the SubscribeStatus
|
||||
// / Status snapshot the UI consumes stays correct — only the timer-driven
|
||||
// warnings are dropped.
|
||||
type noopSessionWatcher struct {
|
||||
recorder *peer.Status
|
||||
}
|
||||
|
||||
func newSessionWatcher(recorder *peer.Status) sessionDeadlineWatcher {
|
||||
return noopSessionWatcher{recorder: recorder}
|
||||
}
|
||||
|
||||
// Update mirrors the real watcher's recorder propagation without the timers or
|
||||
// sanity-check sentinels: a valid deadline is exposed on the status snapshot,
|
||||
// the zero time clears it.
|
||||
func (w noopSessionWatcher) Update(deadline time.Time) error {
|
||||
if w.recorder != nil {
|
||||
w.recorder.SetSessionExpiresAt(deadline)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (noopSessionWatcher) Dismiss() {}
|
||||
func (noopSessionWatcher) Close() {}
|
||||
@@ -24,14 +24,14 @@ type RulePair struct {
|
||||
type Manager struct {
|
||||
dnatFirewall DNATFirewall
|
||||
|
||||
rules map[string]RulePair // keys is the ID of the ForwardRule
|
||||
rules map[firewall.RuleID]RulePair
|
||||
rulesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(dnatFirewall DNATFirewall) *Manager {
|
||||
return &Manager{
|
||||
dnatFirewall: dnatFirewall,
|
||||
rules: make(map[string]RulePair),
|
||||
rules: make(map[firewall.RuleID]RulePair),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
toDelete := make(map[string]RulePair, len(h.rules))
|
||||
toDelete := make(map[firewall.RuleID]RulePair, len(h.rules))
|
||||
for id, r := range h.rules {
|
||||
toDelete[id] = r
|
||||
}
|
||||
@@ -59,6 +59,10 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
|
||||
continue
|
||||
}
|
||||
if rule == nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': backend returned no rule", fwdRule.String()))
|
||||
continue
|
||||
}
|
||||
log.Infof("forward rule has been added '%s'", fwdRule)
|
||||
h.rules[id] = RulePair{
|
||||
ForwardRule: fwdRule,
|
||||
@@ -90,7 +94,7 @@ func (h *Manager) Close() error {
|
||||
}
|
||||
}
|
||||
|
||||
h.rules = make(map[string]RulePair)
|
||||
h.rules = make(map[firewall.RuleID]RulePair)
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,11 +14,11 @@ var (
|
||||
)
|
||||
|
||||
type MocFwRule struct {
|
||||
id string
|
||||
id firewall.RuleID
|
||||
}
|
||||
|
||||
func (m *MocFwRule) ID() string {
|
||||
return string(m.id)
|
||||
func (m *MocFwRule) ID() firewall.RuleID {
|
||||
return m.id
|
||||
}
|
||||
|
||||
type MockDNATFirewall struct {
|
||||
|
||||
@@ -10,21 +10,6 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.RuleProtocol_TCP:
|
||||
return firewallManager.ProtocolTCP, nil
|
||||
case mgmProto.RuleProtocol_UDP:
|
||||
return firewallManager.ProtocolUDP, nil
|
||||
case mgmProto.RuleProtocol_ICMP:
|
||||
return firewallManager.ProtocolICMP, nil
|
||||
case mgmProto.RuleProtocol_ALL:
|
||||
return firewallManager.ProtocolALL, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
}
|
||||
}
|
||||
|
||||
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
|
||||
if portInfo == nil {
|
||||
return nil, errors.New("portInfo cannot be nil")
|
||||
|
||||
@@ -5,10 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -192,27 +190,21 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
// every private-service request) don't contend against each other.
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
// sessionExpiresAt is the absolute UTC instant at which the peer's SSO
|
||||
// session expires. Zero when the peer is not SSO-tracked or login
|
||||
// expiration is disabled. Populated from management LoginResponse /
|
||||
// SyncResponse and exposed via the daemon's Status / SubscribeStatus RPC
|
||||
// so the UI can show remaining time without itself talking to mgm.
|
||||
sessionExpiresAt time.Time
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||
lazyConnectionEnabled bool
|
||||
@@ -228,21 +220,6 @@ type Status struct {
|
||||
eventStreams map[string]chan *proto.SystemEvent
|
||||
eventQueue *EventQueue
|
||||
|
||||
// stateChangeStreams fan-out connection-state changes (connected /
|
||||
// disconnected / connecting / address change / peers list change) to
|
||||
// every active SubscribeStatus gRPC stream. Each subscriber gets a
|
||||
// buffered chan; the notifier non-blockingly pings them so a slow
|
||||
// consumer can never stall the daemon.
|
||||
stateChangeMux sync.Mutex
|
||||
stateChangeStreams map[string]chan struct{}
|
||||
|
||||
// networksRevision bumps whenever the routed-networks set or their
|
||||
// selected state changes (driven by the route manager). Surfaced in the
|
||||
// status snapshot so the UI can fingerprint on it and re-fetch
|
||||
// ListNetworks only on a real change. Atomic so the snapshot builder can
|
||||
// read it without taking mux.
|
||||
networksRevision atomic.Uint64
|
||||
|
||||
ingressGwMgr *ingressgw.Manager
|
||||
|
||||
routeIDLookup routeIDLookup
|
||||
@@ -256,7 +233,6 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
stateChangeStreams: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
@@ -406,7 +382,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -432,7 +407,6 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -458,7 +432,6 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifier.peerListChanged(numPeers)
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -508,7 +481,6 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -545,7 +517,6 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -581,7 +552,6 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -620,7 +590,6 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
if notifyRouter {
|
||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -714,7 +683,6 @@ func (d *Status) FinishPeerListModifications() {
|
||||
for _, rd := range dispatches {
|
||||
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
|
||||
}
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription {
|
||||
@@ -773,41 +741,6 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// SetSessionExpiresAt records the absolute UTC instant at which the peer's
|
||||
// SSO session is set to expire. Pass the zero value to clear (e.g. when the
|
||||
// management server stops publishing a deadline because login expiration was
|
||||
// disabled or the peer is not SSO-tracked). Same-value updates are no-ops;
|
||||
// real changes fan out via notifyStateChange so SubscribeStatus consumers
|
||||
// pick up the new deadline on their next read.
|
||||
func (d *Status) SetSessionExpiresAt(deadline time.Time) {
|
||||
d.mux.Lock()
|
||||
if d.sessionExpiresAt.Equal(deadline) {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.sessionExpiresAt = deadline
|
||||
d.mux.Unlock()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// GetSessionExpiresAt returns the most recently recorded SSO session deadline,
|
||||
// or the zero value when no deadline is tracked. A deadline that has already
|
||||
// slipped into the past reports as "none": once the session has expired it is
|
||||
// no longer a meaningful countdown, and the sessionwatch.Watcher does not
|
||||
// arm a timer at the deadline itself to clear it (only the two pre-expiry
|
||||
// warnings). Without this guard the UI would keep painting a stale
|
||||
// "expires in …" against a moment that has passed until the next login,
|
||||
// extend, or teardown rewrote the value.
|
||||
func (d *Status) GetSessionExpiresAt() time.Time {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
if !d.sessionExpiresAt.IsZero() && d.sessionExpiresAt.Before(time.Now()) {
|
||||
return time.Time{}
|
||||
}
|
||||
return d.sessionExpiresAt
|
||||
}
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
@@ -876,19 +809,11 @@ func (d *Status) CleanLocalPeerState() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.localAddressChanged(fqdn, ip)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
// Health checks re-mark the same state on every probe; skip the fan-out
|
||||
// when nothing actually changed so we don't flood SubscribeStatus
|
||||
// consumers with identical snapshots.
|
||||
if !d.managementState && errors.Is(d.managementError, err) {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.managementState = false
|
||||
d.managementError = err
|
||||
mgm := d.managementState
|
||||
@@ -896,16 +821,11 @@ func (d *Status) MarkManagementDisconnected(err error) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkManagementConnected sets ManagementState to connected
|
||||
func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Lock()
|
||||
if d.managementState && d.managementError == nil {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.managementState = true
|
||||
d.managementError = nil
|
||||
mgm := d.managementState
|
||||
@@ -913,7 +833,6 @@ func (d *Status) MarkManagementConnected() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
@@ -947,10 +866,6 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||
// MarkSignalDisconnected sets SignalState to disconnected
|
||||
func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.mux.Lock()
|
||||
if !d.signalState && errors.Is(d.signalError, err) {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.signalState = false
|
||||
d.signalError = err
|
||||
mgm := d.managementState
|
||||
@@ -958,16 +873,11 @@ func (d *Status) MarkSignalDisconnected(err error) {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Lock()
|
||||
if d.signalState && d.signalError == nil {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.signalState = true
|
||||
d.signalError = nil
|
||||
mgm := d.managementState
|
||||
@@ -975,7 +885,6 @@ func (d *Status) MarkSignalConnected() {
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifier.updateServerStates(mgm, sig)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
@@ -1173,19 +1082,16 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
// ClientStart will notify all listeners about the new service state
|
||||
func (d *Status) ClientStart() {
|
||||
d.notifier.clientStart()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// ClientStop will notify all listeners about the new service state
|
||||
func (d *Status) ClientStop() {
|
||||
d.notifier.clientStop()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// ClientTeardown will notify all listeners about the service is under teardown
|
||||
func (d *Status) ClientTeardown() {
|
||||
d.notifier.clientTearDown()
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// SetConnectionListener set a listener to the notifier
|
||||
@@ -1327,82 +1233,6 @@ func (d *Status) GetEventHistory() []*proto.SystemEvent {
|
||||
return d.eventQueue.GetAll()
|
||||
}
|
||||
|
||||
// SubscribeToStateChanges hands back a channel that receives a tick on
|
||||
// every connection-state change (connected / disconnected / connecting /
|
||||
// address change / peers-list change). The channel is buffered to one
|
||||
// pending tick so a coalesced burst still wakes the consumer exactly
|
||||
// once. Pass the returned id to UnsubscribeFromStateChanges to detach.
|
||||
func (d *Status) SubscribeToStateChanges() (string, <-chan struct{}) {
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
id := uuid.New().String()
|
||||
ch := make(chan struct{}, 1)
|
||||
d.stateChangeStreams[id] = ch
|
||||
return id, ch
|
||||
}
|
||||
|
||||
// UnsubscribeFromStateChanges releases a SubscribeToStateChanges channel
|
||||
// and closes it so any consumer goroutine selecting on the channel
|
||||
// unblocks cleanly.
|
||||
func (d *Status) UnsubscribeFromStateChanges(id string) {
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
if ch, ok := d.stateChangeStreams[id]; ok {
|
||||
close(ch)
|
||||
delete(d.stateChangeStreams, id)
|
||||
}
|
||||
}
|
||||
|
||||
// notifyStateChange wakes every SubscribeToStateChanges subscriber. Drops
|
||||
// the tick if a subscriber's buffer is full — by definition the consumer
|
||||
// is already going to fetch the latest snapshot, so multiple pending ticks
|
||||
// would be redundant.
|
||||
func (d *Status) notifyStateChange() {
|
||||
if _, file, line, ok := runtime.Caller(1); ok {
|
||||
log.Infof("--- notifyStateChange from %s:%d", file, line)
|
||||
}
|
||||
d.stateChangeMux.Lock()
|
||||
defer d.stateChangeMux.Unlock()
|
||||
|
||||
for _, ch := range d.stateChangeStreams {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyStateChange is the public wake-the-subscribers entry point used by
|
||||
// callers that mutate state outside the peer recorder — most importantly
|
||||
// the connect-state machine, which writes StatusNeedsLogin into the
|
||||
// shared contextState (client/internal/state.go) without touching any
|
||||
// recorder field. Without this push the SubscribeStatus stream stays on
|
||||
// the previous snapshot until an unrelated peer/management/signal
|
||||
// change happens to fire notifyStateChange, leaving the UI's status
|
||||
// out of sync with the daemon.
|
||||
func (d *Status) NotifyStateChange() {
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// BumpNetworksRevision increments the routed-networks revision and wakes every
|
||||
// SubscribeStatus subscriber. The route manager calls it when a network map
|
||||
// changes the available routes or when a selection is applied — the peer
|
||||
// status itself only records actively-routed (chosen) networks, so without
|
||||
// this bump a candidate route appearing/disappearing would never reach the UI.
|
||||
func (d *Status) BumpNetworksRevision() {
|
||||
d.networksRevision.Add(1)
|
||||
d.notifyStateChange()
|
||||
}
|
||||
|
||||
// GetNetworksRevision returns the current routed-networks revision, surfaced in
|
||||
// the status snapshot so the UI can detect route/selection changes (see
|
||||
// BumpNetworksRevision).
|
||||
func (d *Status) GetNetworksRevision() uint64 {
|
||||
return d.networksRevision.Load()
|
||||
}
|
||||
|
||||
func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -275,39 +275,3 @@ func TestGetFullStatus(t *testing.T) {
|
||||
assert.Equal(t, signalState, fullStatus.SignalState, "signal status should be equal")
|
||||
assert.ElementsMatch(t, []State{peerState1, peerState2}, fullStatus.Peers, "peers states should match")
|
||||
}
|
||||
|
||||
// notified reports whether a state-change tick is pending on ch, draining it.
|
||||
func notified(ch <-chan struct{}) bool {
|
||||
select {
|
||||
case <-ch:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkServerStateDoesNotNotifyWhenUnchanged(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
_, ch := status.SubscribeToStateChanges()
|
||||
|
||||
// First transition is a real change and must notify.
|
||||
status.MarkManagementConnected()
|
||||
require.True(t, notified(ch), "first connect should notify")
|
||||
|
||||
// Re-marking the same state must not notify again.
|
||||
status.MarkManagementConnected()
|
||||
assert.False(t, notified(ch), "redundant connect should not notify")
|
||||
|
||||
// Same for signal.
|
||||
status.MarkSignalConnected()
|
||||
require.True(t, notified(ch), "first signal connect should notify")
|
||||
status.MarkSignalConnected()
|
||||
assert.False(t, notified(ch), "redundant signal connect should not notify")
|
||||
|
||||
// A genuine change (disconnect with an error) notifies again.
|
||||
err := errors.New("boom")
|
||||
status.MarkManagementDisconnected(err)
|
||||
require.True(t, notified(ch), "disconnect should notify")
|
||||
status.MarkManagementDisconnected(err)
|
||||
assert.False(t, notified(ch), "redundant disconnect should not notify")
|
||||
}
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func newExitNodeTestManager() *DefaultManager {
|
||||
return &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
}
|
||||
|
||||
func exitRoute(netID, peer string, skipAutoApply bool) *route.Route {
|
||||
return &route.Route{
|
||||
NetID: route.NetID(netID),
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
Peer: peer,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPickPreferredExitNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
info exitNodeInfo
|
||||
want route.NetID
|
||||
}{
|
||||
{
|
||||
name: "persisted user selection wins over management",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
userSelected: []route.NetID{"b"},
|
||||
selectedByManagement: []route.NetID{"a"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "multiple user-selected self-heal to deterministic min",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
userSelected: []route.NetID{"c", "a"},
|
||||
},
|
||||
want: "a",
|
||||
},
|
||||
{
|
||||
name: "explicit opt-out keeps none",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b"},
|
||||
userDeselected: []route.NetID{"a", "b"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "fresh defaults to management auto-apply pick",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
selectedByManagement: []route.NetID{"b"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "no user pick and no management auto-apply selects none",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"c", "a", "b"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "user-deselect does not block a management auto-apply sibling",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b"},
|
||||
userDeselected: []route.NetID{"a"},
|
||||
selectedByManagement: []route.NetID{"b"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, pickPreferredExitNode(tt.info), "preferred exit node")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceSingleExitNode(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
all := []route.NetID{"a", "b", "c"}
|
||||
|
||||
m.enforceSingleExitNode("b", all)
|
||||
assert.False(t, m.routeSelector.IsSelected("a"), "a should be deselected")
|
||||
assert.True(t, m.routeSelector.IsSelected("b"), "b should be the only selected exit node")
|
||||
assert.False(t, m.routeSelector.IsSelected("c"), "c should be deselected")
|
||||
|
||||
// Switching the preferred node moves the single selection.
|
||||
m.enforceSingleExitNode("c", all)
|
||||
assert.False(t, m.routeSelector.IsSelected("a"), "a stays deselected")
|
||||
assert.False(t, m.routeSelector.IsSelected("b"), "b should now be deselected")
|
||||
assert.True(t, m.routeSelector.IsSelected("c"), "c should now be selected")
|
||||
|
||||
// Empty preferred turns every exit node off.
|
||||
m.enforceSingleExitNode("", all)
|
||||
for _, id := range all {
|
||||
assert.False(t, m.routeSelector.IsSelected(id), "no exit node should be selected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceSingleExitNode_RespectsDeselectAll(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
|
||||
m.enforceSingleExitNode("b", []route.NetID{"a", "b"})
|
||||
|
||||
assert.True(t, m.routeSelector.IsDeselectAllActive(), "global deselect-all must stay in effect")
|
||||
assert.False(t, m.routeSelector.IsSelected("b"), "no exit node should be forced on while deselect-all is set")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_FreshSelectsOne(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
"lan|192.168.1.0/24": {{NetID: "lan", Network: netip.MustParsePrefix("192.168.1.0/24"), Peer: "p3"}},
|
||||
"exitC|0.0.0.0/0": {exitRoute("exitC", "p4", false)},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
// Exactly one exit node (the deterministic first) is selected.
|
||||
assert.True(t, m.routeSelector.IsSelected("exitA"), "exitA is the deterministic default")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "exitB must not also be selected")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitC"), "exitC must not also be selected")
|
||||
// Non-exit routes are left at their default-on state.
|
||||
assert.True(t, m.routeSelector.IsSelected("lan"), "non-exit route selection is untouched")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_HonorsPersistedPick(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
}
|
||||
all := []route.NetID{"exitA", "exitB"}
|
||||
|
||||
// Simulate the state the runtime select path leaves behind: exactly one
|
||||
// exit node explicitly selected, its sibling deselected.
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exitB"}, true, all))
|
||||
require.NoError(t, m.routeSelector.DeselectRoutes([]route.NetID{"exitA"}, all))
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.True(t, m.routeSelector.IsSelected("exitB"), "persisted pick must stay selected")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "the other exit node stays deselected")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_OptOutKeepsNone(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
}
|
||||
all := []route.NetID{"exitA", "exitB"}
|
||||
|
||||
// User deselected exit nodes and selected none.
|
||||
require.NoError(t, m.routeSelector.DeselectRoutes(all, all))
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "opt-out keeps exitA off")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "opt-out keeps exitB off")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_NoAutoApplySelectsNone(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
// SkipAutoApply=true: management offers the exit nodes but doesn't request
|
||||
// auto-activation, so none should be selected until the user picks one.
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", true)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", true)},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "no auto-apply keeps exitA off")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "no auto-apply keeps exitB off")
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -440,11 +439,6 @@ func (m *DefaultManager) UpdateRoutes(
|
||||
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||
// A new network map can add or drop route/exit-node candidates without
|
||||
// touching any peer's chosen-route state, so the peer status alone
|
||||
// wouldn't notify SubscribeStatus subscribers. Bump the revision so the
|
||||
// UI re-fetches ListNetworks.
|
||||
m.statusRecorder.BumpNetworksRevision()
|
||||
}
|
||||
m.clientRoutes = clientRoutes
|
||||
|
||||
@@ -585,10 +579,6 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
// A selection change flips Network.selected without altering the candidate
|
||||
// set, so bump the revision to push the new state to the UI.
|
||||
m.statusRecorder.BumpNetworksRevision()
|
||||
}
|
||||
|
||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||
@@ -708,22 +698,15 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
return ips
|
||||
}
|
||||
|
||||
// updateRouteSelectorFromManagement reconciles exit-node selection on every
|
||||
// network map: it keeps at most one exit node selected — the user's persisted
|
||||
// pick, else whatever management marks for auto-apply (SkipAutoApply=false),
|
||||
// else none. We never auto-activate an exit node the map doesn't request; it
|
||||
// stays off until the user picks it. Exit nodes are mutually exclusive, but the
|
||||
// RouteSelector stores routes with default-on semantics, so without this every
|
||||
// available exit node would report selected at once.
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
info := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(info.allIDs) == 0 {
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
preferred := pickPreferredExitNode(info)
|
||||
m.enforceSingleExitNode(preferred, info.allIDs)
|
||||
m.logExitNodeUpdate(info, preferred)
|
||||
m.updateExitNodeSelections(exitNodeInfo)
|
||||
m.logExitNodeUpdate(exitNodeInfo)
|
||||
}
|
||||
|
||||
type exitNodeInfo struct {
|
||||
@@ -733,10 +716,6 @@ type exitNodeInfo struct {
|
||||
userDeselected []route.NetID
|
||||
}
|
||||
|
||||
// collectExitNodeInfo categorises the available exit nodes by their persisted
|
||||
// selection state. It keys on the base (v4) NetID and skips the synthesized
|
||||
// "-v6" partner, which inherits its base's selection through the RouteSelector
|
||||
// — counting it separately would double-count the pair.
|
||||
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
|
||||
var info exitNodeInfo
|
||||
|
||||
@@ -746,9 +725,6 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
||||
}
|
||||
|
||||
netID := haID.NetID()
|
||||
if strings.HasSuffix(string(netID), route.V6ExitSuffix) {
|
||||
continue
|
||||
}
|
||||
info.allIDs = append(info.allIDs, netID)
|
||||
|
||||
if m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||
@@ -785,69 +761,45 @@ func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID r
|
||||
}
|
||||
}
|
||||
|
||||
// pickPreferredExitNode chooses the single exit node to keep selected. In order:
|
||||
// - a persisted user selection wins (deterministic if several survive from
|
||||
// legacy state, so the set self-heals down to one);
|
||||
// - otherwise activate only what management marks for auto-apply
|
||||
// (SkipAutoApply=false); the lexicographically first if it marks several.
|
||||
//
|
||||
// Returns "" when neither holds — we never force an arbitrary exit node on. A
|
||||
// route the map doesn't auto-apply stays off until the user selects it.
|
||||
// info.userDeselected is informational only: an explicit deselect simply keeps
|
||||
// that route out of both lists above, so it can't be picked.
|
||||
func pickPreferredExitNode(info exitNodeInfo) route.NetID {
|
||||
if len(info.userSelected) > 0 {
|
||||
return minNetID(info.userSelected)
|
||||
}
|
||||
if len(info.selectedByManagement) > 0 {
|
||||
return minNetID(info.selectedByManagement)
|
||||
}
|
||||
return ""
|
||||
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
|
||||
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
|
||||
m.deselectExitNodes(routesToDeselect)
|
||||
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
|
||||
}
|
||||
|
||||
// enforceSingleExitNode makes preferred the only selected exit node: every other
|
||||
// available exit node is deselected and preferred (if any) is selected, without
|
||||
// disturbing non-exit route selections. A global deselect-all is left untouched
|
||||
// so the user's "all off" stays in effect.
|
||||
func (m *DefaultManager) enforceSingleExitNode(preferred route.NetID, allIDs []route.NetID) {
|
||||
if m.routeSelector.IsDeselectAllActive() {
|
||||
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
|
||||
var routesToDeselect []route.NetID
|
||||
for _, netID := range allIDs {
|
||||
if !m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||
routesToDeselect = append(routesToDeselect, netID)
|
||||
}
|
||||
}
|
||||
return routesToDeselect
|
||||
}
|
||||
|
||||
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
|
||||
if len(routesToDeselect) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
others := make([]route.NetID, 0, len(allIDs))
|
||||
for _, id := range allIDs {
|
||||
if id != preferred {
|
||||
others = append(others, id)
|
||||
}
|
||||
}
|
||||
if len(others) > 0 {
|
||||
if err := m.routeSelector.DeselectRoutes(others, allIDs); err != nil {
|
||||
log.Warnf("deselect other exit nodes: %v", err)
|
||||
}
|
||||
}
|
||||
if preferred != "" {
|
||||
if err := m.routeSelector.SelectRoutes([]route.NetID{preferred}, true, allIDs); err != nil {
|
||||
log.Warnf("select preferred exit node %q: %v", preferred, err)
|
||||
}
|
||||
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to deselect exit nodes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo, preferred route.NetID) {
|
||||
log.Debugf("Exit node selection: %d available, preferred=%q (%d user-selected, %d user-deselected, %d management-selected)",
|
||||
len(info.allIDs), preferred, len(info.userSelected), len(info.userDeselected), len(info.selectedByManagement))
|
||||
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
|
||||
if len(selectedByManagement) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to select exit nodes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// minNetID returns the lexicographically smallest NetID, for a deterministic
|
||||
// default pick that stays stable across restarts.
|
||||
func minNetID(ids []route.NetID) route.NetID {
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
best := ids[0]
|
||||
for _, id := range ids[1:] {
|
||||
if id < best {
|
||||
best = id
|
||||
}
|
||||
}
|
||||
return best
|
||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
|
||||
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
|
||||
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
|
||||
}
|
||||
|
||||
@@ -124,16 +124,6 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
return rs.isSelectedLocked(routeID)
|
||||
}
|
||||
|
||||
// IsDeselectAllActive reports whether the global "deselect all" flag is set,
|
||||
// i.e. the user disabled every route. Callers enforcing per-route invariants
|
||||
// (e.g. single exit node) should leave the selection untouched when it is.
|
||||
func (rs *RouteSelector) IsDeselectAllActive() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -2,10 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type StatusType string
|
||||
@@ -36,37 +33,17 @@ func CtxGetState(ctx context.Context) *contextState {
|
||||
}
|
||||
|
||||
type contextState struct {
|
||||
err error
|
||||
status StatusType
|
||||
mutex sync.Mutex
|
||||
onChange func()
|
||||
}
|
||||
|
||||
// SetOnChange installs a callback fired after every successful Set. Used by
|
||||
// the daemon to wire the status recorder's notifyStateChange so any
|
||||
// state.Set in the connect/login paths pushes a fresh snapshot to
|
||||
// SubscribeStatus subscribers without each callsite having to opt in.
|
||||
// The callback runs outside the contextState mutex to avoid a lock-order
|
||||
// dependency with the recorder's stateChangeMux.
|
||||
func (c *contextState) SetOnChange(fn func()) {
|
||||
c.mutex.Lock()
|
||||
c.onChange = fn
|
||||
c.mutex.Unlock()
|
||||
err error
|
||||
status StatusType
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func (c *contextState) Set(update StatusType) {
|
||||
if _, file, line, ok := runtime.Caller(1); ok {
|
||||
log.Infof("--- state.Set(%s) from %s:%d", update, file, line)
|
||||
}
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.status = update
|
||||
c.err = nil
|
||||
cb := c.onChange
|
||||
c.mutex.Unlock()
|
||||
|
||||
if cb != nil {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *contextState) Status() (StatusType, error) {
|
||||
@@ -80,17 +57,6 @@ func (c *contextState) Status() (StatusType, error) {
|
||||
return c.status, nil
|
||||
}
|
||||
|
||||
// CurrentStatus returns the last status set via Set, ignoring any wrapped
|
||||
// error. Use when the status is needed for reporting purposes (e.g. the
|
||||
// status snapshot stream) and a transient wrapped error from a retry loop
|
||||
// shouldn't blank out the underlying status.
|
||||
func (c *contextState) CurrentStatus() StatusType {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *contextState) Wrap(err error) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@@ -32,6 +32,9 @@
|
||||
</File>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
|
||||
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||
<?endif ?>
|
||||
|
||||
<ServiceInstall
|
||||
Id="NetBirdService"
|
||||
@@ -59,14 +62,6 @@
|
||||
<Component Id="NetbirdAumidRegistry" Guid="*">
|
||||
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
|
||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
||||
<!-- Pre-seed the CLSID the Wails notifications service reads on
|
||||
first startup (notifications_windows.go:getGUID looks for
|
||||
the CustomActivator value under this key). Without this
|
||||
the service generates a fresh per-install UUID, which
|
||||
diverges from the ToastActivatorCLSID set on the Start
|
||||
Menu / Desktop shortcuts above and the COM activator
|
||||
never fires when a toast is clicked. -->
|
||||
<RegistryValue Name="CustomActivator" Type="string" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
</StandardDirectory>
|
||||
@@ -90,40 +85,10 @@
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" TerminateProcess="0" />
|
||||
|
||||
<!-- WebView2 evergreen runtime detection.
|
||||
Probe both the per-machine and per-user EdgeUpdate keys; if either
|
||||
reports a non-empty `pv` value the runtime is already installed
|
||||
and we skip the bootstrapper. -->
|
||||
<Property Id="WEBVIEW2_VERSION_HKLM">
|
||||
<RegistrySearch Id="WV2HKLM" Root="HKLM"
|
||||
Key="SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
|
||||
Name="pv" Type="raw" Bitness="always64" />
|
||||
</Property>
|
||||
<Property Id="WEBVIEW2_VERSION_HKCU">
|
||||
<RegistrySearch Id="WV2HKCU" Root="HKCU"
|
||||
Key="Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
|
||||
Name="pv" Type="raw" />
|
||||
</Property>
|
||||
|
||||
<!-- Embed the bootstrapper payload. Path is relative to the WiX
|
||||
working directory; sign-pipelines stages it next to client/
|
||||
via `wails3 generate webview2bootstrapper`. -->
|
||||
<Binary Id="WebView2Bootstrapper" SourceFile=".\client\MicrosoftEdgeWebview2Setup.exe" />
|
||||
|
||||
<CustomAction Id="InstallWebView2"
|
||||
BinaryRef="WebView2Bootstrapper"
|
||||
ExeCommand="/silent /install"
|
||||
Execute="deferred"
|
||||
Impersonate="no"
|
||||
Return="check" />
|
||||
|
||||
<InstallExecuteSequence>
|
||||
<Custom Action="InstallWebView2" Before="InstallFinalize"
|
||||
Condition="NOT WEBVIEW2_VERSION_HKLM AND NOT WEBVIEW2_VERSION_HKCU AND NOT REMOVE" />
|
||||
</InstallExecuteSequence>
|
||||
|
||||
<!-- Icons -->
|
||||
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\build\windows\icon.ico" />
|
||||
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />
|
||||
<Property Id="ARPPRODUCTICON" Value="NetbirdIcon" />
|
||||
|
||||
</Package>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,12 +24,6 @@ service DaemonService {
|
||||
// Status of the service.
|
||||
rpc Status(StatusRequest) returns (StatusResponse) {}
|
||||
|
||||
// SubscribeStatus pushes a fresh StatusResponse on connection state
|
||||
// changes (Connected / Disconnected / Connecting / address change /
|
||||
// peers list change). The first message on the stream is the current
|
||||
// snapshot, so a freshly-subscribed UI doesn't need to also call Status.
|
||||
rpc SubscribeStatus(StatusRequest) returns (stream StatusResponse) {}
|
||||
|
||||
// Down stops engine work in the daemon.
|
||||
rpc Down(DownRequest) returns (DownResponse) {}
|
||||
|
||||
@@ -115,25 +109,6 @@ service DaemonService {
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||
|
||||
// RequestExtendAuthSession initiates an SSO session-extension flow.
|
||||
// The daemon prepares a PKCE/device-code request against the IdP and
|
||||
// returns the verification URI; the UI is expected to open it. The flow
|
||||
// state is kept in the daemon until WaitExtendAuthSession completes it.
|
||||
rpc RequestExtendAuthSession(RequestExtendAuthSessionRequest) returns (RequestExtendAuthSessionResponse) {}
|
||||
|
||||
// WaitExtendAuthSession blocks until the user finishes the SSO step
|
||||
// started by RequestExtendAuthSession, then forwards the resulting JWT
|
||||
// to the management server's ExtendAuthSession RPC. Returns the new
|
||||
// session expiry deadline. The tunnel stays up the entire time.
|
||||
rpc WaitExtendAuthSession(WaitExtendAuthSessionRequest) returns (WaitExtendAuthSessionResponse) {}
|
||||
|
||||
// DismissSessionWarning records that the user clicked "Dismiss" on the
|
||||
// T-WarningLead interactive notification, suppressing the auto-opened
|
||||
// SessionAboutToExpire dialog that would otherwise fire at
|
||||
// T-FinalWarningLead for the current deadline. Idempotent and best-effort:
|
||||
// a missed call only means the fallback dialog will still appear.
|
||||
rpc DismissSessionWarning(DismissSessionWarningRequest) returns (DismissSessionWarningResponse) {}
|
||||
|
||||
// StartCPUProfile starts CPU profiling in the daemon
|
||||
rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {}
|
||||
|
||||
@@ -252,12 +227,6 @@ message UpRequest {
|
||||
optional string profileName = 1;
|
||||
optional string username = 2;
|
||||
reserved 3;
|
||||
// async instructs the daemon to start the connection attempt and return
|
||||
// immediately without waiting for the engine to become ready. Status updates
|
||||
// are delivered via the SubscribeStatus stream. When false (the default) the
|
||||
// RPC blocks until the engine is running or gives up, which is the behaviour
|
||||
// needed by the CLI.
|
||||
bool async = 4;
|
||||
}
|
||||
|
||||
message UpResponse {}
|
||||
@@ -275,10 +244,6 @@ message StatusResponse{
|
||||
FullStatus fullStatus = 2;
|
||||
// NetBird daemon version
|
||||
string daemonVersion = 3;
|
||||
// Absolute UTC instant at which the peer's SSO session expires.
|
||||
// Unset when the peer is not SSO-registered or login expiration is disabled.
|
||||
// The UI derives "warning active" from this value and its own clock.
|
||||
google.protobuf.Timestamp sessionExpiresAt = 4;
|
||||
}
|
||||
|
||||
message DownRequest {}
|
||||
@@ -443,12 +408,6 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
|
||||
// networksRevision bumps whenever the set of routed networks (route and
|
||||
// exit-node candidates) or their selected state changes. The UI fingerprints
|
||||
// on it to know when to re-fetch ListNetworks via the push stream, instead
|
||||
// of polling on every status snapshot.
|
||||
uint64 networksRevision = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -839,55 +798,6 @@ message WaitJWTTokenResponse {
|
||||
int64 expiresIn = 3;
|
||||
}
|
||||
|
||||
// RequestExtendAuthSessionRequest kicks off the session-extension SSO flow.
|
||||
message RequestExtendAuthSessionRequest {
|
||||
// Optional OIDC login_hint (typically the user's email) to pre-fill the
|
||||
// IdP login form.
|
||||
optional string hint = 1;
|
||||
}
|
||||
|
||||
// RequestExtendAuthSessionResponse carries the verification URI the UI
|
||||
// should open in a browser. The daemon retains the flow state and resolves
|
||||
// it via WaitExtendAuthSession.
|
||||
message RequestExtendAuthSessionResponse {
|
||||
// verification URI for the user to open in the browser
|
||||
string verificationURI = 1;
|
||||
// complete verification URI (with embedded user code)
|
||||
string verificationURIComplete = 2;
|
||||
// user code to enter on verification URI (for device-code flows)
|
||||
string userCode = 3;
|
||||
// device code for matching the WaitExtendAuthSession call to this flow
|
||||
string deviceCode = 4;
|
||||
// expiration time in seconds for the device code / PKCE flow
|
||||
int64 expiresIn = 5;
|
||||
}
|
||||
|
||||
// WaitExtendAuthSessionRequest is sent by the UI after it opens the
|
||||
// verification URI. The daemon blocks on this call until the user
|
||||
// completes (or aborts) the SSO step.
|
||||
message WaitExtendAuthSessionRequest {
|
||||
// device code returned by RequestExtendAuthSession
|
||||
string deviceCode = 1;
|
||||
// user code for verification
|
||||
string userCode = 2;
|
||||
}
|
||||
|
||||
// WaitExtendAuthSessionResponse carries the refreshed deadline returned
|
||||
// by the management server. Unset when the management server reports the
|
||||
// peer is not eligible for session extension.
|
||||
message WaitExtendAuthSessionResponse {
|
||||
google.protobuf.Timestamp sessionExpiresAt = 1;
|
||||
}
|
||||
|
||||
// DismissSessionWarningRequest is sent by the UI when the user clicks
|
||||
// "Dismiss" on the T-WarningLead notification.
|
||||
message DismissSessionWarningRequest {}
|
||||
|
||||
// DismissSessionWarningResponse acknowledges the dismissal. Carries no
|
||||
// payload — the daemon's only obligation is to silence the upcoming
|
||||
// T-FinalWarningLead fallback for the current deadline.
|
||||
message DismissSessionWarningResponse {}
|
||||
|
||||
// StartCPUProfileRequest for starting CPU profiling
|
||||
message StartCPUProfileRequest {}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -52,10 +52,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
if engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
// Background ctx: the bundle wants a full, fresh probe regardless
|
||||
// of the DebugBundle RPC client's lifetime. The engine's own ctx
|
||||
// still aborts it on shutdown.
|
||||
engine.RunHealthProbes(context.Background(), true)
|
||||
engine.RunHealthProbes(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,17 +172,6 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
|
||||
return nil, fmt.Errorf("select routes: %w", err)
|
||||
}
|
||||
|
||||
// Exit nodes are mutually exclusive: if this selection activates an
|
||||
// exit node, deselect every other available exit node so two can't be
|
||||
// selected at once. Non-exit route selections are left untouched.
|
||||
if requestActivatesExitNode(routes, routesMap) {
|
||||
if others := otherExitNodeIDs(routesMap, routes); len(others) > 0 {
|
||||
if err := routeSelector.DeselectRoutes(others, netIdRoutes); err != nil {
|
||||
return nil, fmt.Errorf("deselect sibling exit nodes: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||
|
||||
@@ -260,38 +249,3 @@ func toNetIDs(routes []string) []route.NetID {
|
||||
}
|
||||
return netIDs
|
||||
}
|
||||
|
||||
func isExitNodeRoutes(routes []*route.Route) bool {
|
||||
return len(routes) > 0 && (route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network))
|
||||
}
|
||||
|
||||
// requestActivatesExitNode reports whether any requested NetID maps to an exit
|
||||
// node (default route) in the current route table.
|
||||
func requestActivatesExitNode(requested []route.NetID, routesMap map[route.NetID][]*route.Route) bool {
|
||||
for _, id := range requested {
|
||||
if isExitNodeRoutes(routesMap[id]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// otherExitNodeIDs returns every available exit-node NetID that is not in the
|
||||
// requested set — the siblings to deselect so a single exit node stays active.
|
||||
func otherExitNodeIDs(routesMap map[route.NetID][]*route.Route, requested []route.NetID) []route.NetID {
|
||||
keep := make(map[route.NetID]struct{}, len(requested))
|
||||
for _, id := range requested {
|
||||
keep[id] = struct{}{}
|
||||
}
|
||||
var others []route.NetID
|
||||
for id, routes := range routesMap {
|
||||
if !isExitNodeRoutes(routes) {
|
||||
continue
|
||||
}
|
||||
if _, ok := keep[id]; ok {
|
||||
continue
|
||||
}
|
||||
others = append(others, id)
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestExitNodeSelectionHelpers(t *testing.T) {
|
||||
routesMap := map[route.NetID][]*route.Route{
|
||||
"exitA": {{Network: netip.MustParsePrefix("0.0.0.0/0")}},
|
||||
"exitB": {{Network: netip.MustParsePrefix("::/0")}},
|
||||
"lan": {{Network: netip.MustParsePrefix("192.168.0.0/16")}},
|
||||
}
|
||||
|
||||
assert.True(t, requestActivatesExitNode([]route.NetID{"exitA"}, routesMap), "v4 default route is an exit node")
|
||||
assert.True(t, requestActivatesExitNode([]route.NetID{"exitB"}, routesMap), "v6 default route is an exit node")
|
||||
assert.False(t, requestActivatesExitNode([]route.NetID{"lan"}, routesMap), "lan route is not an exit node")
|
||||
assert.False(t, requestActivatesExitNode([]route.NetID{"missing"}, routesMap), "unknown id is not an exit node")
|
||||
|
||||
others := otherExitNodeIDs(routesMap, []route.NetID{"exitB"})
|
||||
assert.ElementsMatch(t, []route.NetID{"exitA"}, others, "only the other exit node is a sibling; the lan route is ignored")
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// healthProbeRunner runs the full, expensive probe (network round-trips to
|
||||
// management, signal and the relays) and reports whether every component was
|
||||
// healthy. ctx cancels the probe when the caller gives up. Satisfied by
|
||||
// *internal.Engine.
|
||||
type healthProbeRunner interface {
|
||||
RunHealthProbes(ctx context.Context, waitForResult bool) bool
|
||||
}
|
||||
|
||||
// statsRefresher does the cheap WireGuard-stats refresh callers fall back to
|
||||
// when a fresh probe isn't warranted. Satisfied by *peer.Status.
|
||||
type statsRefresher interface {
|
||||
RefreshWireGuardStats() error
|
||||
}
|
||||
|
||||
// probeThrottle rate-limits and single-flights the daemon's health probes.
|
||||
//
|
||||
// Health probes are expensive (network round-trips to management, signal and
|
||||
// the relays), while Status(GetFullPeerStatus=true) RPCs can arrive frequently
|
||||
// and concurrently — the desktop UI alone issues one per connect/disconnect.
|
||||
// probeThrottle keeps that load bounded with two rules:
|
||||
//
|
||||
// - Single-flight: only one probe runs at a time. Callers that pile up while
|
||||
// a probe is in flight share its result instead of each launching another,
|
||||
// even when that probe failed. A failed probe therefore does not make every
|
||||
// waiter re-probe in turn; the next, non-overlapping caller can try again.
|
||||
// - Throttle: after a fully successful probe the result is cached for
|
||||
// interval. While any component is unhealthy the cache is not advanced, so
|
||||
// later callers keep probing frequently and notice recovery quickly — the
|
||||
// intentional "probe often while unhealthy" behaviour from the original
|
||||
// design.
|
||||
type probeThrottle struct {
|
||||
interval time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
lastOK time.Time // last fully-successful probe; drives the throttle window
|
||||
completedAt time.Time // when the most recent probe finished; drives single-flight sharing
|
||||
}
|
||||
|
||||
func newProbeThrottle(interval time.Duration) *probeThrottle {
|
||||
return &probeThrottle{interval: interval}
|
||||
}
|
||||
|
||||
// Run decides whether to run a fresh health probe or serve the most recent
|
||||
// result. It serialises concurrent callers: at most one runner.RunHealthProbes
|
||||
// executes at a time and the rest call refresher.RefreshWireGuardStats and read
|
||||
// the snapshot it produced.
|
||||
//
|
||||
// Both calls run while the throttle's lock is held, so a slow probe blocks
|
||||
// other callers until it completes — that blocking is the single-flight
|
||||
// guarantee. ctx is forwarded to RunHealthProbes so a caller that gives up
|
||||
// cancels the in-flight probe (and any caller still queued on the lock falls
|
||||
// through quickly once it acquires it, since the probe ctx is already done).
|
||||
func (t *probeThrottle) Run(ctx context.Context, runner healthProbeRunner, refresher statsRefresher, waitForResult bool) {
|
||||
entered := time.Now()
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// A probe that finished after we entered ran while we were waiting on the
|
||||
// lock — i.e. a peer in the same burst already probed for us, so share its
|
||||
// result rather than launch another. This holds even when that probe
|
||||
// failed, so a failed probe doesn't make every waiter re-probe in turn.
|
||||
sharedRecentProbe := t.completedAt.After(entered)
|
||||
throttled := time.Since(t.lastOK) <= t.interval
|
||||
|
||||
if sharedRecentProbe || throttled {
|
||||
if err := refresher.RefreshWireGuardStats(); err != nil {
|
||||
log.Debugf("failed to refresh WireGuard stats: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
healthy := runner.RunHealthProbes(ctx, waitForResult)
|
||||
t.completedAt = time.Now()
|
||||
if healthy {
|
||||
t.lastOK = t.completedAt
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user