mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 23:36:39 +00:00
Compare commits
125 Commits
fix/login-
...
v0.65.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f53155562f | ||
|
|
edce11b34d | ||
|
|
841b2d26c6 | ||
|
|
d3eeb6d8ee | ||
|
|
7ebf37ef20 | ||
|
|
64b849c801 | ||
|
|
69d4b5d821 | ||
|
|
3dfa97dcbd | ||
|
|
1ddc9ce2bf | ||
|
|
2de1949018 | ||
|
|
fc88399c23 | ||
|
|
6981fdce7e | ||
|
|
08403f64aa | ||
|
|
391221a986 | ||
|
|
7bc85107eb | ||
|
|
3be16d19a0 | ||
|
|
af8f730bda | ||
|
|
c3f176f348 | ||
|
|
0119f3e9f4 | ||
|
|
1b96648d4d | ||
|
|
d2f9653cea | ||
|
|
194a986926 | ||
|
|
f7732557fa | ||
|
|
d488f58311 | ||
|
|
6fdc00ff41 | ||
|
|
b20d484972 | ||
|
|
8931293343 | ||
|
|
7b830d8f72 | ||
|
|
3a0cf230a1 | ||
|
|
0c990ab662 | ||
|
|
101c813e98 | ||
|
|
5333e55a81 | ||
|
|
81c11df103 | ||
|
|
f74bc48d16 | ||
|
|
0169e4540f | ||
|
|
cead3f38ee | ||
|
|
b55262d4a2 | ||
|
|
2248ff392f | ||
|
|
06966da012 | ||
|
|
d4f7df271a | ||
|
|
5299549eb6 | ||
|
|
7d791620a6 | ||
|
|
44ab454a13 | ||
|
|
11f50d6c38 | ||
|
|
05af39a69b | ||
|
|
074df56c3d | ||
|
|
2381e216e4 | ||
|
|
ded04b7627 | ||
|
|
67211010f7 | ||
|
|
c61568ceb4 | ||
|
|
737d6061bf | ||
|
|
ee3a67d2d8 | ||
|
|
1a32e4c223 | ||
|
|
269d5d1cba | ||
|
|
a1de2b8a98 | ||
|
|
d0221a3e72 | ||
|
|
8da23daae3 | ||
|
|
f86022eace | ||
|
|
ee54827f94 | ||
|
|
e908dea702 | ||
|
|
030650a905 | ||
|
|
e01998815e | ||
|
|
07e4a5a23c | ||
|
|
b3a2992a10 | ||
|
|
202fa47f2b | ||
|
|
4888021ba6 | ||
|
|
a0b0b664b6 | ||
|
|
50da5074e7 | ||
|
|
58daa674ef | ||
|
|
245481f33b | ||
|
|
b352ab84c0 | ||
|
|
3ce5d6a4f8 | ||
|
|
4c2eb2af73 | ||
|
|
daf1449174 | ||
|
|
1ff7abe909 | ||
|
|
067c77e49e | ||
|
|
291e640b28 | ||
|
|
efb954b7d6 | ||
|
|
cac9326d3d | ||
|
|
520d9c66cf | ||
|
|
ff10498a8b | ||
|
|
00b747ad5d | ||
|
|
d9118eb239 | ||
|
|
94de656fae | ||
|
|
37abab8b69 | ||
|
|
b12c084a50 | ||
|
|
394ad19507 | ||
|
|
614e7d5b90 | ||
|
|
f7967f9ae3 | ||
|
|
684fc0d2a2 | ||
|
|
0ad0c81899 | ||
|
|
e8863fbb55 | ||
|
|
9c9d8e17d7 | ||
|
|
fb71b0d04b | ||
|
|
ab7d6b2196 | ||
|
|
9c5b2575e3 | ||
|
|
00e2689ffb | ||
|
|
cf535f8c61 | ||
|
|
24df442198 | ||
|
|
8722b79799 | ||
|
|
afcdef6121 | ||
|
|
12a7fa24d7 | ||
|
|
6ff9aa0366 | ||
|
|
e586c20e36 | ||
|
|
5393ad948f | ||
|
|
20d6beff1b | ||
|
|
d35b7d675c | ||
|
|
f012fb8592 | ||
|
|
7142d45ef3 | ||
|
|
9bd578d4ea | ||
|
|
f022e34287 | ||
|
|
7bb4fc3450 | ||
|
|
07856f516c | ||
|
|
08b782d6ba | ||
|
|
80a312cc9c | ||
|
|
9ba067391f | ||
|
|
7ac65bf1ad | ||
|
|
2e9c316852 | ||
|
|
96cdd56902 | ||
|
|
9ed1437442 | ||
|
|
a8604ef51c | ||
|
|
d88e046d00 | ||
|
|
1d2c7776fd | ||
|
|
4035f07248 | ||
|
|
ef2721f4e1 |
@@ -1,15 +1,15 @@
|
||||
FROM golang:1.23-bullseye
|
||||
FROM golang:1.25-bookworm
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends\
|
||||
gettext-base=0.21-4 \
|
||||
iptables=1.8.7-1 \
|
||||
libgl1-mesa-dev=20.3.5-1 \
|
||||
xorg-dev=1:7.7+22 \
|
||||
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
|
||||
gettext-base=0.21-12 \
|
||||
iptables=1.8.9-2 \
|
||||
libgl1-mesa-dev=22.3.6-1+deb12u1 \
|
||||
xorg-dev=1:7.7+23 \
|
||||
libayatana-appindicator3-dev=0.5.92-1 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& go install -v golang.org/x/tools/gopls@v0.18.1
|
||||
&& go install -v golang.org/x/tools/gopls@latest
|
||||
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
||||
.env
|
||||
.env.*
|
||||
*.pem
|
||||
*.key
|
||||
*.crt
|
||||
*.p12
|
||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
@@ -39,11 +39,11 @@ jobs:
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
|
||||
5
.github/workflows/golang-test-freebsd.yml
vendored
5
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
release: "14.2"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
@@ -39,13 +39,12 @@ jobs:
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
time go test -timeout 1m -failfast ./client/iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./signal/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
|
||||
65
.github/workflows/golang-test-linux.yml
vendored
65
.github/workflows/golang-test-linux.yml
vendored
@@ -97,6 +97,16 @@ jobs:
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
- name: Build combined
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build combined 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: combined
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -144,7 +154,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -200,11 +210,11 @@ jobs:
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
-e CONTAINER=${CONTAINER} \
|
||||
golang:1.24-alpine \
|
||||
golang:1.25-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
@@ -259,7 +269,54 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./relay/... ./shared/relay/...
|
||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||
|
||||
test_proxy:
|
||||
name: "Proxy / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test -timeout 10m -p 1 ./proxy/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
|
||||
11
.github/workflows/golangci-lint.yml
vendored
11
.github/workflows/golangci-lint.yml
vendored
@@ -19,8 +19,8 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
|
||||
skip: go.mod,go.sum
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -52,7 +52,10 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=12m --out-format colored-line-number
|
||||
skip-cache: true
|
||||
skip-save-cache: true
|
||||
cache-invalidation-interval: 0
|
||||
args: --timeout=12m
|
||||
|
||||
20
.github/workflows/release.yml
vendored
20
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.0"
|
||||
SIGN_PIPE_VER: "v0.1.1"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
pkg install -y git curl portlint go
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request'
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
@@ -176,6 +176,7 @@ jobs:
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
@@ -185,6 +186,19 @@ jobs:
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: Tag and push PR images (amd64 only)
|
||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
||||
run: |
|
||||
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||
grep '^ghcr.io/' | while read -r SRC; do
|
||||
IMG_NAME="${SRC%%:*}"
|
||||
DST="${IMG_NAME}:${PR_TAG}"
|
||||
echo "Tagging ${SRC} -> ${DST}"
|
||||
docker tag "$SRC" "$DST"
|
||||
docker push "$DST"
|
||||
done
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
@@ -243,6 +243,7 @@ jobs:
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
sleep 30
|
||||
docker compose logs
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
|
||||
|
||||
|
||||
13
.github/workflows/wasm-build-validation.yml
vendored
13
.github/workflows/wasm-build-validation.yml
vendored
@@ -14,6 +14,9 @@ jobs:
|
||||
js_lint:
|
||||
name: "JS / Lint"
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GOOS: js
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -24,16 +27,14 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
skip-cache: true
|
||||
skip-pkg-cache: true
|
||||
skip-build-cache: true
|
||||
- name: Run golangci-lint for WASM
|
||||
run: |
|
||||
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
|
||||
skip-save-cache: true
|
||||
cache-invalidation-interval: 0
|
||||
working-directory: ./client
|
||||
continue-on-error: true
|
||||
|
||||
js_build:
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
.run
|
||||
*.iml
|
||||
dist/
|
||||
!proxy/web/dist/
|
||||
bin/
|
||||
.env
|
||||
conf.json
|
||||
@@ -31,3 +32,4 @@ infrastructure_files/setup-*.env
|
||||
.DS_Store
|
||||
vendor/
|
||||
/netbird
|
||||
client/netbird-electron/
|
||||
|
||||
255
.golangci.yaml
255
.golangci.yaml
@@ -1,139 +1,124 @@
|
||||
run:
|
||||
# Timeout for analysis, e.g. 30s, 5m.
|
||||
# Default: 1m
|
||||
timeout: 6m
|
||||
|
||||
# This file contains only configs which differ from defaults.
|
||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
||||
linters-settings:
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: false
|
||||
|
||||
gosec:
|
||||
includes:
|
||||
- G101 # Look for hard coded credentials
|
||||
#- G102 # Bind to all interfaces
|
||||
- G103 # Audit the use of unsafe block
|
||||
- G104 # Audit errors not checked
|
||||
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
|
||||
#- G107 # Url provided to HTTP request as taint input
|
||||
- G108 # Profiling endpoint automatically exposed on /debug/pprof
|
||||
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
|
||||
- G110 # Potential DoS vulnerability via decompression bomb
|
||||
- G111 # Potential directory traversal
|
||||
#- G112 # Potential slowloris attack
|
||||
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
|
||||
#- G114 # Use of net/http serve function that has no support for setting timeouts
|
||||
- G201 # SQL query construction using format string
|
||||
- G202 # SQL query construction using string concatenation
|
||||
- G203 # Use of unescaped data in HTML templates
|
||||
#- G204 # Audit use of command execution
|
||||
- G301 # Poor file permissions used when creating a directory
|
||||
- G302 # Poor file permissions used with chmod
|
||||
- G303 # Creating tempfile using a predictable path
|
||||
- G304 # File path provided as taint input
|
||||
- G305 # File traversal when extracting zip/tar archive
|
||||
- G306 # Poor file permissions used when writing to a new file
|
||||
- G307 # Poor file permissions used when creating a file with os.Create
|
||||
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
|
||||
#- G402 # Look for bad TLS connection settings
|
||||
- G403 # Ensure minimum RSA key length of 2048 bits
|
||||
#- G404 # Insecure random number source (rand)
|
||||
#- G501 # Import blocklist: crypto/md5
|
||||
- G502 # Import blocklist: crypto/des
|
||||
- G503 # Import blocklist: crypto/rc4
|
||||
- G504 # Import blocklist: net/http/cgi
|
||||
#- G505 # Import blocklist: crypto/sha1
|
||||
- G601 # Implicit memory aliasing of items from a range statement
|
||||
- G602 # Slice access out of bounds
|
||||
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- captLocal
|
||||
- deprecatedComment
|
||||
|
||||
govet:
|
||||
# Enable all analyzers.
|
||||
# Default: false
|
||||
enable-all: false
|
||||
enable:
|
||||
- nilness
|
||||
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
severity: warning
|
||||
disabled: false
|
||||
arguments:
|
||||
- "checkPrivateReceivers"
|
||||
- "sayRepetitiveInsteadOfStutters"
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
# Default: false
|
||||
all: true
|
||||
|
||||
version: "2"
|
||||
linters:
|
||||
disable-all: true
|
||||
default: none
|
||||
enable:
|
||||
## enabled by default
|
||||
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
|
||||
- gosimple # specializes in simplifying a code
|
||||
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- ineffassign # detects when assignments to existing variables are not used
|
||||
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
|
||||
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
|
||||
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unused # checks for unused constants, variables, functions and types
|
||||
## disable by default but the have interesting results so lets add them
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- dupword # dupword checks for duplicate words in the source code
|
||||
- durationcheck # durationcheck checks for two durations multiplied together
|
||||
- forbidigo # forbidigo forbids identifiers
|
||||
- gocritic # provides diagnostics that check for bugs, performance and style issues
|
||||
- gosec # inspects source code for security problems
|
||||
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
|
||||
- misspell # misspess finds commonly misspelled English words in comments
|
||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
- bodyclose
|
||||
- dupword
|
||||
- durationcheck
|
||||
- errcheck
|
||||
- forbidigo
|
||||
- gocritic
|
||||
- gosec
|
||||
- govet
|
||||
- ineffassign
|
||||
- mirror
|
||||
- misspell
|
||||
- nilerr
|
||||
- nilnil
|
||||
- predeclared
|
||||
- revive
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
- unused
|
||||
- wastedassign
|
||||
settings:
|
||||
errcheck:
|
||||
check-type-assertions: false
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- captLocal
|
||||
- deprecatedComment
|
||||
gosec:
|
||||
includes:
|
||||
- G101
|
||||
- G103
|
||||
- G104
|
||||
- G106
|
||||
- G108
|
||||
- G109
|
||||
- G110
|
||||
- G111
|
||||
- G201
|
||||
- G202
|
||||
- G203
|
||||
- G301
|
||||
- G302
|
||||
- G303
|
||||
- G304
|
||||
- G305
|
||||
- G306
|
||||
- G307
|
||||
- G403
|
||||
- G502
|
||||
- G503
|
||||
- G504
|
||||
- G601
|
||||
- G602
|
||||
govet:
|
||||
enable:
|
||||
- nilness
|
||||
enable-all: false
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
arguments:
|
||||
- checkPrivateReceivers
|
||||
- sayRepetitiveInsteadOfStutters
|
||||
severity: warning
|
||||
disabled: false
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
rules:
|
||||
- linters:
|
||||
- forbidigo
|
||||
path: management/cmd/root\.go
|
||||
- linters:
|
||||
- forbidigo
|
||||
path: signal/cmd/root\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: sharedsock/filter\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: client/firewall/iptables/rule\.go
|
||||
- linters:
|
||||
- gosec
|
||||
- mirror
|
||||
path: test\.go
|
||||
- linters:
|
||||
- nilnil
|
||||
path: mock\.go
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: grpc.DialContext is deprecated
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: grpc.WithBlock is deprecated
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1001"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1008"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1012"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
# Set to 0 to disable.
|
||||
# Default: 3
|
||||
max-same-issues: 5
|
||||
|
||||
exclude-rules:
|
||||
# allow fmt
|
||||
- path: management/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: signal/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: sharedsock/filter\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: client/firewall/iptables/rule\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: test\.go
|
||||
linters:
|
||||
- mirror
|
||||
- gosec
|
||||
- path: mock\.go
|
||||
linters:
|
||||
- nilnil
|
||||
# Exclude specific deprecation warnings for grpc methods
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.DialContext is deprecated"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.WithBlock is deprecated"
|
||||
formatters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
183
.goreleaser.yaml
183
.goreleaser.yaml
@@ -106,6 +106,26 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-server
|
||||
dir: combined
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-server
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-upload
|
||||
dir: upload-server
|
||||
env: [CGO_ENABLED=0]
|
||||
@@ -120,6 +140,20 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-proxy
|
||||
dir: proxy/cmd/proxy
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-proxy
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -520,6 +554,104 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
@@ -598,6 +730,18 @@ docker_manifests:
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
@@ -675,6 +819,43 @@ docker_manifests:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
@@ -713,8 +894,10 @@ checksum:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
BSD 3-Clause License
|
||||
|
||||
13
README.md
13
README.md
@@ -38,6 +38,11 @@
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
@@ -55,8 +60,8 @@
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
### Self-Host NetBird (Video)
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
@@ -85,7 +90,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
@@ -98,7 +103,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.2
|
||||
FROM alpine:3.23.2
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -3,15 +3,7 @@ package android
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||
}
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
return true, err
|
||||
}
|
||||
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
var needsLogin bool
|
||||
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
|
||||
if err == nil {
|
||||
go urlOpener.OnLoginSuccess()
|
||||
}
|
||||
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
err, _ = authClient.Login(a.ctx, "", jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.OnLoginSuccess()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||
return backoff.RetryNotify(
|
||||
bf,
|
||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -136,6 +134,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
if level == proto.LogLevel_UNKNOWN {
|
||||
//nolint
|
||||
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
|
||||
}
|
||||
|
||||
@@ -220,21 +219,37 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||
cpuProfilingStarted := false
|
||||
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
|
||||
} else {
|
||||
cpuProfilingStarted = true
|
||||
defer func() {
|
||||
if cpuProfilingStarted {
|
||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
if cpuProfilingStarted {
|
||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||
} else {
|
||||
cpuProfilingStarted = false
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -301,25 +316,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context(), true)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -378,7 +374,8 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogFile: logFilePath,
|
||||
LogPath: logFilePath,
|
||||
CPUProfile: nil,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -81,6 +80,7 @@ var loginCmd = &cobra.Command{
|
||||
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -206,6 +206,7 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
|
||||
func switchProfile(ctx context.Context, profileName string, username string) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -275,18 +276,15 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, "", "")
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check login required: %v", err)
|
||||
}
|
||||
|
||||
jwtToken := ""
|
||||
@@ -298,23 +296,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
return fmt.Errorf("login failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -342,11 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
|
||||
@@ -390,6 +390,7 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||
if err := validateDestinationPort(remoteAddr); err != nil {
|
||||
return fmt.Errorf("invalid remote address: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
@@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||
if err := validateDestinationPort(localAddr); err != nil {
|
||||
return fmt.Errorf("invalid local address: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
@@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDestinationPort checks that the destination address has a valid port.
|
||||
// Port 0 is only valid for bind addresses (where the OS picks an available port),
|
||||
// not for destination addresses where we need to connect.
|
||||
func validateDestinationPort(addr string) error {
|
||||
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse address %s: %w", addr, err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
return fmt.Errorf("port 0 is not valid for destination address")
|
||||
}
|
||||
|
||||
if port < 0 || port > 65535 {
|
||||
return fmt.Errorf("port %d out of range (1-65535)", port)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||
|
||||
@@ -99,17 +99,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
statusOutputString = outputInformationHolder.FullDetailSummary()
|
||||
case jsonFlag:
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
statusOutputString, err = outputInformationHolder.JSON()
|
||||
case yamlFlag:
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
statusOutputString, err = outputInformationHolder.YAML()
|
||||
default:
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -124,6 +124,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
@@ -89,9 +90,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
@@ -100,6 +98,8 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
@@ -118,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
@@ -216,6 +216,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -16,10 +16,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -29,6 +31,14 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// PeerConnStatus is a peer's connection status.
|
||||
type PeerConnStatus = peer.ConnStatus
|
||||
|
||||
const (
|
||||
// PeerStatusConnected indicates the peer is in connected state.
|
||||
PeerStatusConnected = peer.StatusConnected
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -38,6 +48,7 @@ type Client struct {
|
||||
setupKey string
|
||||
jwtToken string
|
||||
connect *internal.ConnectClient
|
||||
recorder *peer.Status
|
||||
}
|
||||
|
||||
// Options configures a new Client.
|
||||
@@ -66,6 +77,10 @@ type Options struct {
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -134,6 +149,8 @@ func New(opts Options) (*Client, error) {
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||
@@ -153,6 +170,7 @@ func New(opts Options) (*Client, error) {
|
||||
setupKey: opts.SetupKey,
|
||||
jwtToken: opts.JWTToken,
|
||||
config: config,
|
||||
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -161,26 +179,38 @@ func New(opts Options) (*Client, error) {
|
||||
func (c *Client) Start(startCtx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cancel != nil {
|
||||
if c.connect != nil {
|
||||
return ErrClientAlreadyStarted
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
|
||||
defer func() {
|
||||
if c.connect == nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
|
||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create auth client: %w", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan struct{})
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
if err := client.Run(run, ""); err != nil {
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
@@ -197,6 +227,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
c.cancel = cancel
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -211,17 +242,23 @@ func (c *Client) Stop(ctx context.Context) error {
|
||||
return ErrClientNotStarted
|
||||
}
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
c.cancel = nil
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
connect := c.connect
|
||||
go func() {
|
||||
done <- c.connect.Stop()
|
||||
done <- connect.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancel = nil
|
||||
c.connect = nil
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
c.cancel = nil
|
||||
c.connect = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
@@ -315,6 +352,57 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
_ = engine.RunHealthProbes(false)
|
||||
}
|
||||
}
|
||||
|
||||
return c.recorder.GetFullStatus(), nil
|
||||
}
|
||||
|
||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncResp, err := engine.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get sync response: %w", err)
|
||||
}
|
||||
|
||||
return syncResp, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the client and its components.
|
||||
func (c *Client) SetLogLevel(levelStr string) error {
|
||||
level, err := logrus.ParseLevel(levelStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse log level: %w", err)
|
||||
}
|
||||
|
||||
logrus.SetLevel(level)
|
||||
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect != nil {
|
||||
connect.SetLogLevel(level)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
|
||||
@@ -386,11 +386,8 @@ func (m *aclManager) updateState() {
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
if ip.IsUnspecified() {
|
||||
matchByIP = false
|
||||
}
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
|
||||
@@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
return fmt.Errorf("init notrack chain: %w", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
@@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
tableRaw = "raw"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
|
||||
return fmt.Errorf("add output sport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
|
||||
return fmt.Errorf("add output dport notrack rule: %w", err)
|
||||
}
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
|
||||
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
|
||||
}
|
||||
|
||||
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
|
||||
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
|
||||
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChain() error {
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
log.Debugf("cleanup notrack chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("create chain: %w", err)
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||
log.Debugf("delete output jump rule: %v", delErr)
|
||||
}
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNoTrackChain() error {
|
||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
|
||||
return fmt.Errorf("clear and delete chain: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
t.Logf(" [%d] %s", i, rule)
|
||||
}
|
||||
|
||||
var denyRuleIndex, acceptRuleIndex int = -1, -1
|
||||
var denyRuleIndex, acceptRuleIndex = -1, -1
|
||||
for i, rule := range rules {
|
||||
if strings.Contains(rule, "DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
|
||||
@@ -168,6 +168,10 @@ type Manager interface {
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -48,8 +49,10 @@ type Manager struct {
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
@@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
return fmt.Errorf("init notrack chains: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
@@ -288,7 +295,15 @@ func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclManager.Flush()
|
||||
if err := m.aclManager.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshNoTrackChains(); err != nil {
|
||||
log.Errorf("failed to refresh notrack chains: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
@@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
)
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
|
||||
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
|
||||
//
|
||||
// Traffic flows that need NOTRACK:
|
||||
//
|
||||
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
|
||||
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
|
||||
// Matched by: sport=wgPort
|
||||
//
|
||||
// 2. Egress: Proxy -> WireGuard (via raw socket)
|
||||
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 3. Ingress: Packets to WireGuard
|
||||
// dst=127.0.0.1:wgPort
|
||||
// Matched by: dport=wgPort
|
||||
//
|
||||
// 4. Ingress: Packets to proxy (after eBPF rewrite)
|
||||
// dst=127.0.0.1:proxyPort
|
||||
// Matched by: dport=proxyPort
|
||||
//
|
||||
// Rules are cleaned up when the firewall manager is closed.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
|
||||
return fmt.Errorf("notrack chains not initialized")
|
||||
}
|
||||
|
||||
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
|
||||
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
|
||||
loopback := []byte{127, 0, 0, 1}
|
||||
|
||||
// Egress rules: match outgoing loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackOutputChain.Table,
|
||||
Chain: m.notrackOutputChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
// Ingress rules: match incoming loopback UDP packets
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.notrackPreroutingChain.Table,
|
||||
Chain: m.notrackPreroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
|
||||
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
|
||||
&expr.Counter{},
|
||||
&expr.Notrack{},
|
||||
},
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush notrack rules: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawOutput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawPrerouting,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRaw,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush chain creation: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) refreshNoTrackChains() error {
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, c := range chains {
|
||||
if c.Table.Name != tableName {
|
||||
continue
|
||||
}
|
||||
switch c.Name {
|
||||
case chainNameRawOutput:
|
||||
m.notrackOutputChain = c
|
||||
case chainNameRawPrerouting:
|
||||
m.notrackPreroutingChain = c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
t.Logf("Found %d rules in nftables chain", len(rules))
|
||||
|
||||
// Find the accept and deny rules and verify deny comes before accept
|
||||
var acceptRuleIndex, denyRuleIndex int = -1, -1
|
||||
var acceptRuleIndex, denyRuleIndex = -1, -1
|
||||
for i, rule := range rules {
|
||||
hasAcceptHTTPSet := false
|
||||
hasDenyHTTPSet := false
|
||||
@@ -208,11 +208,13 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
for _, e := range rule.Exprs {
|
||||
// Check for set lookup
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
if lookup.SetName == "accept-http" {
|
||||
switch lookup.SetName {
|
||||
case "accept-http":
|
||||
hasAcceptHTTPSet = true
|
||||
} else if lookup.SetName == "deny-http" {
|
||||
case "deny-http":
|
||||
hasDenyHTTPSet = true
|
||||
}
|
||||
|
||||
}
|
||||
// Check for port 80
|
||||
if cmp, ok := e.(*expr.Cmp); ok {
|
||||
@@ -222,9 +224,10 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
}
|
||||
// Check for verdict
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
if verdict.Kind == expr.VerdictAccept {
|
||||
switch verdict.Kind {
|
||||
case expr.VerdictAccept:
|
||||
action = "ACCEPT"
|
||||
} else if verdict.Kind == expr.VerdictDrop {
|
||||
case expr.VerdictDrop:
|
||||
action = "DROP"
|
||||
}
|
||||
}
|
||||
@@ -386,6 +389,97 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
const octet2Count = 25
|
||||
const octet3Count = 255
|
||||
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
|
||||
for i := 1; i < octet2Count; i++ {
|
||||
for j := 1; j < octet3Count; j++ {
|
||||
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
|
||||
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||
}
|
||||
}
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
prefixes,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||
t.Helper()
|
||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||
|
||||
@@ -48,9 +48,11 @@ const (
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
refreshRulesMapError = "refresh rules map: %w"
|
||||
)
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||
@@ -481,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
if nftRule.Handle == 0 {
|
||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
@@ -513,16 +520,35 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.AddSet(nfset, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
initialElements := elements[:min(maxElements, nElements)]
|
||||
|
||||
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
||||
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
var subEnd int
|
||||
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
||||
subEnd = min(subStart+maxElements, nElements)
|
||||
subElement := elements[subStart:subEnd]
|
||||
nSubPrefixes := len(subElement) / 2
|
||||
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
||||
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||
}
|
||||
|
||||
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
@@ -639,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
// TODO: rollback ipset counter
|
||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
||||
keys := []string{
|
||||
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||
}
|
||||
for _, key := range keys {
|
||||
rule, ok := r.rules[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||
}
|
||||
delete(r.rules, key)
|
||||
}
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
@@ -907,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
rule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1308,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||
}
|
||||
|
||||
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||
// counters will be off until the next successful removal or refresh cycle.
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
// TODO: rollback set counter
|
||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||
}
|
||||
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
rule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
|
||||
// (e.g. from failed flushes) and updates handles for all existing rules.
|
||||
func (r *router) refreshRulesMap() error {
|
||||
var merr *multierror.Error
|
||||
newRules := make(map[string]*nftables.Rule)
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list rules: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||
// preserve existing entries for this chain since we can't verify their state
|
||||
for k, v := range r.rules {
|
||||
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||
newRules[k] = v
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
newRules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
r.rules = newRules
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
@@ -1608,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var needsFlush bool
|
||||
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
if dnatRule.Handle == 0 {
|
||||
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||
if err := r.conn.DelRule(masqRule); err != nil {
|
||||
if masqRule.Handle == 0 {
|
||||
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
if needsFlush {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
}
|
||||
|
||||
if merr == nil {
|
||||
@@ -1736,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -719,3 +720,137 @@ func deleteWorkTable() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Add a real rule to the kernel
|
||||
ruleKey, err := r.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
firewall.ProtocolTCP,
|
||||
nil,
|
||||
&firewall.Port{Values: []uint16{80}},
|
||||
firewall.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||
})
|
||||
|
||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||
staleKey := "stale-rule-that-does-not-exist"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
}
|
||||
|
||||
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
|
||||
|
||||
err = r.refreshRulesMap()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
|
||||
|
||||
realRule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "real rule should still exist after refresh")
|
||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||
}
|
||||
|
||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Inject a stale entry with Handle=0
|
||||
staleKey := "stale-route-rule"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
}
|
||||
|
||||
// DeleteRouteRule should not return an error for stale handles
|
||||
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||
}
|
||||
|
||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "staletest",
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
rtr := manager.router
|
||||
|
||||
// First add succeeds
|
||||
err = rtr.AddNatRule(pair)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, rtr.RemoveNatRule(pair))
|
||||
})
|
||||
|
||||
// Corrupt the handle to simulate stale state
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
|
||||
// Adding the same rule again should succeed despite stale handles
|
||||
err = rtr.AddNatRule(pair)
|
||||
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
|
||||
|
||||
// Verify rules exist in kernel
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err)
|
||||
|
||||
found := 0
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||
}
|
||||
|
||||
@@ -3,12 +3,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
m.resetState()
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
m.resetState()
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
|
||||
@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
||||
return t.tombstone.Load()
|
||||
}
|
||||
|
||||
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||
if t.tombstone.Load() {
|
||||
return true
|
||||
}
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||
}
|
||||
|
||||
// SetTombstone safely marks the connection for deletion
|
||||
func (t *TCPConnTrack) SetTombstone() {
|
||||
t.tombstone.Store(true)
|
||||
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
if exists && !conn.IsSupersededBy(flags) {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.IsTombstone() {
|
||||
if !exists || conn.IsSupersededBy(flags) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and gracefully close a connection (server-initiated close)
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Server sends FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Client sends FIN-ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
|
||||
// Server sends final ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Connection should be tombstoned
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn, "old connection should still be in map")
|
||||
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||
|
||||
// Now reuse the same port for a new connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
// The old tombstoned entry should be replaced with a new one
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn, "new connection should exist")
|
||||
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||
|
||||
// SYN-ACK for the new connection should be valid
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
|
||||
// Data transfer should work
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||
require.True(t, valid, "data should be allowed on new connection")
|
||||
})
|
||||
|
||||
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and RST a connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||
|
||||
// Reuse the same port
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn)
|
||||
require.False(t, newConn.IsTombstone())
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||
})
|
||||
|
||||
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
clientIP := srcIP
|
||||
serverIP := dstIP
|
||||
clientPort := srcPort
|
||||
serverPort := dstPort
|
||||
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||
|
||||
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Server-initiated close to reach Closed/tombstoned:
|
||||
// Server FIN (opposite dir) → CloseWait
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
// Client FIN-ACK (same dir as conn) → LastAck
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
|
||||
require.True(t, conn.IsTombstone())
|
||||
|
||||
// New inbound connection on same ports
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn)
|
||||
require.False(t, newConn.IsTombstone())
|
||||
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||
|
||||
// Complete handshake: server SYN-ACK, then client ACK
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
})
|
||||
|
||||
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.True(t, conn.IsTombstone())
|
||||
|
||||
// Late ACK should be rejected (tombstoned)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||
|
||||
// Late outbound ACK should not create a new connection (not a SYN)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Active close: client (outbound initiator) sends FIN first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Server ACKs the FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Server sends its own FIN
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||
|
||||
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||
newConn := tracker.connections[key]
|
||||
require.NotNil(t, newConn, "new connection should exist")
|
||||
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||
|
||||
// SYN-ACK for new connection should be valid
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||
})
|
||||
|
||||
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish outbound connection and close via active close → TIME-WAIT
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||
|
||||
// Simulate what the filter does next: TrackInbound via the normal path
|
||||
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||
newConn := tracker.connections[invertedKey]
|
||||
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||
require.False(t, newConn.IsTombstone())
|
||||
})
|
||||
|
||||
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Establish and active close → TIME-WAIT
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPTimeoutHandling(t *testing.T) {
|
||||
// Create tracker with a very short timeout for testing
|
||||
shortTimeout := 100 * time.Millisecond
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -12,11 +13,13 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
@@ -24,12 +27,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
layerTypeAll = 0
|
||||
layerTypeAll = 255
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
@@ -89,6 +93,7 @@ type Manager struct {
|
||||
incomingDenyRules map[netip.Addr]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
@@ -262,10 +268,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
wgPrefix := iface.Address().Network
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
rule, err := m.addRouteFiltering(
|
||||
@@ -439,19 +442,7 @@ func (m *Manager) AddPeerFiltering(
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
r.protoLayer = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
r.protoLayer = layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
r.protoLayer = layers.LayerTypeICMPv4
|
||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||
r.protoLayer = layers.LayerTypeICMPv6
|
||||
}
|
||||
case firewall.ProtocolALL:
|
||||
r.protoLayer = layerTypeAll
|
||||
}
|
||||
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
||||
|
||||
m.mutex.Lock()
|
||||
var targetMap map[netip.Addr]RuleSet
|
||||
@@ -495,17 +486,22 @@ func (m *Manager) addRouteFiltering(
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||
return existingRule, nil
|
||||
}
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
id: string(ruleKey),
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
if destination.IsPrefix() {
|
||||
rule.destinations = []netip.Prefix{destination.Prefix}
|
||||
@@ -513,6 +509,7 @@ func (m *Manager) addRouteFiltering(
|
||||
|
||||
m.routeRules = append(m.routeRules, &rule)
|
||||
m.routeRules.Sort()
|
||||
m.routeRulesMap[ruleKey] = &rule
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
@@ -529,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
ruleID := rule.ID()
|
||||
ruleKey := nbid.RuleID(rule.ID())
|
||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||
}
|
||||
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
return r.id == string(ruleKey)
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
delete(m.routeRulesMap, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -584,6 +586,48 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// resetState clears all firewall rules and closes connection trackers.
|
||||
// Must be called with m.mutex held.
|
||||
func (m *Manager) resetState() {
|
||||
maps.Clear(m.outgoingRules)
|
||||
maps.Clear(m.incomingDenyRules)
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
@@ -795,7 +839,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
|
||||
var sum uint32 = pseudoSum
|
||||
var sum = pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
@@ -945,7 +989,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
|
||||
if blocked {
|
||||
_, pnum := getProtocolFromPacket(d)
|
||||
pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
@@ -1010,20 +1054,22 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return false
|
||||
}
|
||||
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
protoLayer := d.decoded[1]
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
|
||||
if !pass {
|
||||
proto := getProtocolFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
Protocol: proto,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
@@ -1052,16 +1098,33 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return true
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
return layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
return layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
if ipLayer == layers.LayerTypeIPv6 {
|
||||
return layers.LayerTypeICMPv6
|
||||
}
|
||||
return layers.LayerTypeICMPv4
|
||||
case firewall.ProtocolALL:
|
||||
return layerTypeAll
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return firewall.ProtocolTCP, nftypes.TCP
|
||||
return nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return firewall.ProtocolUDP, nftypes.UDP
|
||||
return nftypes.UDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return firewall.ProtocolICMP, nftypes.ICMP
|
||||
return nftypes.ICMP
|
||||
default:
|
||||
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||
return nftypes.ProtocolUnknown
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1233,19 +1296,30 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
}
|
||||
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
|
||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||
// TODO: handle ipv6 vs ipv4 icmp rules
|
||||
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
||||
return false
|
||||
}
|
||||
|
||||
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
destMatched := false
|
||||
for _, dst := range rule.destinations {
|
||||
if dst.Contains(dstAddr) {
|
||||
@@ -1264,21 +1338,8 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
return false
|
||||
}
|
||||
|
||||
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return sourceMatched
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
|
||||
@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
for _, tc := range cases {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
|
||||
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
})
|
||||
}
|
||||
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
}
|
||||
})
|
||||
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
|
||||
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||
|
||||
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the packet should be allowed
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.1.0/24"),
|
||||
netip.MustParsePrefix("100.64.2.0/24"),
|
||||
}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add rule first time
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
// Add the same rule again
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
// These should be the same (idempotent) like nftables/iptables implementations
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, ruleCount,
|
||||
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||
}
|
||||
|
||||
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||
// different parameters get distinct IDs.
|
||||
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
|
||||
// Add first rule
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add different rule (different destination)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-2"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||
"Different rules should have different IDs")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||
}
|
||||
|
||||
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||
// rule during a network map update does not disrupt existing traffic.
|
||||
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
require.True(t, pass, "Traffic should pass with rule in place")
|
||||
|
||||
// Re-add same rule (simulates network map update)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||
// would remove the only matching rule and cause a traffic gap.
|
||||
if rule1.ID() != rule2.ID() {
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
assert.True(t, passAfter,
|
||||
"Traffic should still pass after rule update - no gap should occur")
|
||||
}
|
||||
|
||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||
// returns the same rule without duplicating.
|
||||
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Call blockInvalidRouted directly multiple times
|
||||
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule3)
|
||||
|
||||
// All should return the same rule
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||
|
||||
// Should have exactly 1 route rule
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||
|
||||
// Verify the rule blocks traffic to the WG network
|
||||
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||
}
|
||||
|
||||
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||
// EnableRouting multiple times (as happens on each route update) does not
|
||||
// accumulate duplicate block rules in the routeRules slice.
|
||||
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||
for i := 0; i < 5; i++ {
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
}
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, ruleCount,
|
||||
"Repeated EnableRouting should not accumulate block rules")
|
||||
}
|
||||
|
||||
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||
// rule multiple times does not create duplicate entries.
|
||||
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Simulate 5 network map updates with the same route rule
|
||||
for i := 0; i < 5; i++ {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
}
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, ruleCount,
|
||||
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||
}
|
||||
|
||||
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||
// after adding it multiple times works correctly.
|
||||
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
manager := setupTestManager(t)
|
||||
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add same rule twice
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||
|
||||
// Delete using first reference
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify traffic no longer passes
|
||||
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||
}
|
||||
|
||||
func setupTestManager(t *testing.T) *Manager {
|
||||
t.Helper()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
return manager
|
||||
}
|
||||
@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Add multiple deny rules for different ports
|
||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||
|
||||
// Delete the first deny rule
|
||||
err = m.DeletePeerRule(rule1[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount = len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||
|
||||
// Delete the second deny rule
|
||||
err = m.DeletePeerRule(rule2[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, exists := m.incomingDenyRules[addr]
|
||||
m.mutex.RUnlock()
|
||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||
}
|
||||
|
||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Simulate 10 network map updates: add rule, delete old, add new
|
||||
for i := 0; i < 10; i++ {
|
||||
// Add a deny rule
|
||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an allow rule
|
||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete them (simulating ACL manager cleanup)
|
||||
for _, r := range rules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
for _, r := range allowRules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||
}
|
||||
|
||||
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||
// IP are stored in separate maps and don't interfere with each other.
|
||||
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
|
||||
// Add allow rule for port 80
|
||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add deny rule for port 22
|
||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
m.mutex.RLock()
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||
|
||||
// Delete allow rule should not affect deny rule
|
||||
err = m.DeletePeerRule(allowRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||
|
||||
// Delete deny rule
|
||||
err = m.DeletePeerRule(denyRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, denyExists := m.incomingDenyRules[addr]
|
||||
_, allowExists := m.incomingRules[addr]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.False(t, denyExists, "Deny rules should be empty")
|
||||
require.False(t, allowExists, "Allow rules should be empty")
|
||||
}
|
||||
|
||||
func TestManagerReset(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
@@ -767,9 +919,9 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||
@@ -784,8 +936,8 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all original prefixes are still included
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||
|
||||
@@ -793,8 +945,8 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
|
||||
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||
@@ -922,7 +1074,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
for _, tc := range testCases {
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
@@ -16,7 +17,7 @@ type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu uint32
|
||||
mtu atomic.Uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
@@ -28,7 +29,7 @@ func (e *endpoint) IsAttached() bool {
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu
|
||||
return e.mtu.Load()
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
@@ -82,6 +83,22 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *endpoint) Close() {
|
||||
// Endpoint cleanup - nothing to do as device is managed externally
|
||||
}
|
||||
|
||||
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
|
||||
// Link address is not used for this endpoint type
|
||||
}
|
||||
|
||||
func (e *endpoint) SetMTU(mtu uint32) {
|
||||
e.mtu.Store(mtu)
|
||||
}
|
||||
|
||||
func (e *endpoint) SetOnCloseAction(func()) {
|
||||
// No action needed on close
|
||||
}
|
||||
|
||||
type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -35,14 +36,16 @@ type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
@@ -60,8 +63,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
endpoint.mtu.Store(uint32(mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("create NIC: %v", err)
|
||||
@@ -103,15 +106,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
pingSemaphore: make(chan struct{}, 3),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -129,6 +133,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
f.checkICMPCapability()
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
@@ -198,3 +204,24 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||
func (f *Forwarder) checkICMPCapability() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.hasRawICMPAccess = false
|
||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||
}
|
||||
|
||||
f.hasRawICMPAccess = true
|
||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||
}
|
||||
|
||||
@@ -2,8 +2,11 @@ package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -14,30 +17,95 @@ import (
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
// For Echo Requests, send and wait for response
|
||||
if icmpHdr.Type() == header.ICMPv4Echo {
|
||||
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
|
||||
if !f.hasRawICMPAccess {
|
||||
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
|
||||
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||
select {
|
||||
case f.pingSemaphore <- struct{}{}:
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||
rxBytes := pkt.Size()
|
||||
|
||||
go func() {
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPAccess {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
} else {
|
||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||
// The caller is responsible for closing the returned connection.
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||
}
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
sendTime := time.Now()
|
||||
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
@@ -45,38 +113,22 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
}
|
||||
}()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
rxBytes := pkt.Size()
|
||||
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
response := make([]byte, f.endpoint.mtu.Load())
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
@@ -85,31 +137,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
||||
return 0
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return len(fullPacket)
|
||||
return f.injectICMPReply(id, response[:n])
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
@@ -152,3 +180,95 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
|
||||
// This is used as a fallback when raw socket access is not available.
|
||||
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||
|
||||
pingStart := time.Now()
|
||||
if err := cmd.Run(); err != nil {
|
||||
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
|
||||
icmpType, icmpCode, err)
|
||||
return
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// buildPingCommand creates a platform-specific ping command.
|
||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||
timeoutSec := int(timeout.Seconds())
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "darwin", "ios":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "freebsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "openbsd", "netbsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "windows":
|
||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
default:
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
|
||||
// Returns the size of the injected packet.
|
||||
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||
replyICMP := make([]byte, len(icmpData))
|
||||
copy(replyICMP, icmpData)
|
||||
|
||||
replyICMPHdr := header.ICMPv4(replyICMP)
|
||||
replyICMPHdr.SetType(header.ICMPv4EchoReply)
|
||||
replyICMPHdr.SetChecksum(0)
|
||||
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
|
||||
|
||||
return f.injectICMPReply(id, replyICMP)
|
||||
}
|
||||
|
||||
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
|
||||
// Returns the total size of the injected packet, or 0 if injection failed.
|
||||
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, icmpPayload...)
|
||||
|
||||
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
|
||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -131,10 +132,10 @@ func (f *udpForwarder) cleanup() {
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
@@ -144,7 +145,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
@@ -162,7 +163,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err != nil {
|
||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
@@ -173,10 +174,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
inConn := gonet.NewUDPConn(&wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
pConn := &udpPacketConn{
|
||||
@@ -199,7 +200,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
@@ -208,6 +209,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
@@ -348,7 +350,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
|
||||
@@ -130,6 +130,7 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
|
||||
|
||||
@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -16,9 +18,18 @@ const (
|
||||
maxBatchSize = 1024 * 16
|
||||
maxMessageSize = 1024 * 2
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
logChannelSize = 1000
|
||||
defaultLogChanSize = 1000
|
||||
)
|
||||
|
||||
func getLogChannelSize() int {
|
||||
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultLogChanSize
|
||||
}
|
||||
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
@@ -69,7 +80,7 @@ type Logger struct {
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
msgChannel: make(chan logMessage, logChannelSize),
|
||||
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
@@ -168,6 +179,15 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
|
||||
@@ -234,9 +234,10 @@ func TestInboundPortDNATNegative(t *testing.T) {
|
||||
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
|
||||
|
||||
d = parsePacket(t, packet)
|
||||
if tc.protocol == layers.IPProtocolTCP {
|
||||
switch tc.protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
|
||||
} else if tc.protocol == layers.IPProtocolUDP {
|
||||
case layers.IPProtocolUDP:
|
||||
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -34,7 +34,7 @@ type RouteRule struct {
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
destinations []netip.Prefix
|
||||
proto firewall.Protocol
|
||||
protoLayer gopacket.LayerType
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
|
||||
@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
protoLayer := d.decoded[1]
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
|
||||
|
||||
strId := string(id)
|
||||
if id == nil {
|
||||
|
||||
169
client/iface/bind/dual_stack_conn.go
Normal file
169
client/iface/bind/dual_stack_conn.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoIPv4Conn = errors.New("no IPv4 connection available")
|
||||
errNoIPv6Conn = errors.New("no IPv6 connection available")
|
||||
errInvalidAddr = errors.New("invalid address type")
|
||||
)
|
||||
|
||||
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
|
||||
// to the appropriate connection based on the destination address.
|
||||
// ReadFrom is not used in the hot path - ICEBind receives packets via
|
||||
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
|
||||
type DualStackPacketConn struct {
|
||||
ipv4Conn net.PacketConn
|
||||
ipv6Conn net.PacketConn
|
||||
|
||||
readFromWarn sync.Once
|
||||
}
|
||||
|
||||
// NewDualStackPacketConn creates a new dual-stack packet connection.
|
||||
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
|
||||
return &DualStackPacketConn{
|
||||
ipv4Conn: ipv4Conn,
|
||||
ipv6Conn: ipv6Conn,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom reads from the available connection (preferring IPv4).
|
||||
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
|
||||
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
|
||||
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
|
||||
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
|
||||
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
|
||||
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
d.readFromWarn.Do(func() {
|
||||
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
|
||||
})
|
||||
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.ReadFrom(b)
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.ReadFrom(b)
|
||||
}
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
|
||||
// WriteTo writes to the appropriate connection based on the address type.
|
||||
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp",
|
||||
Addr: addr,
|
||||
Err: errInvalidAddr,
|
||||
}
|
||||
}
|
||||
|
||||
if udpAddr.IP.To4() == nil {
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.WriteTo(b, addr)
|
||||
}
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp6",
|
||||
Addr: addr,
|
||||
Err: errNoIPv6Conn,
|
||||
}
|
||||
}
|
||||
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.WriteTo(b, addr)
|
||||
}
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp4",
|
||||
Addr: addr,
|
||||
Err: errNoIPv4Conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes both connections.
|
||||
func (d *DualStackPacketConn) Close() error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address of the IPv4 connection if available,
|
||||
// otherwise the IPv6 connection.
|
||||
func (d *DualStackPacketConn) LocalAddr() net.Addr {
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.LocalAddr()
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.LocalAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline sets the deadline for both connections.
|
||||
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline for both connections.
|
||||
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline for both connections.
|
||||
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
119
client/iface/bind/dual_stack_conn_bench_test.go
Normal file
119
client/iface/bind/dual_stack_conn_bench_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
|
||||
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
|
||||
payload = make([]byte, 1200)
|
||||
)
|
||||
|
||||
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
|
||||
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = conn.WriteTo(payload, ipv4Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
|
||||
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(conn, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, ipv4Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
|
||||
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(nil, conn)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, ipv6Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
|
||||
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn4.Close()
|
||||
|
||||
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer conn6.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(conn4, conn6)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, ipv4Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
|
||||
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn4.Close()
|
||||
|
||||
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer conn6.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(conn4, conn6)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, ipv6Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
|
||||
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn4.Close()
|
||||
|
||||
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
b.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer conn6.Close()
|
||||
|
||||
ds := NewDualStackPacketConn(conn4, conn6)
|
||||
addrs := []net.Addr{ipv4Addr, ipv6Addr}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ds.WriteTo(payload, addrs[i&1])
|
||||
}
|
||||
}
|
||||
191
client/iface/bind/dual_stack_conn_test.go
Normal file
191
client/iface/bind/dual_stack_conn_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
|
||||
ipv4Conn := &mockPacketConn{network: "udp4"}
|
||||
ipv6Conn := &mockPacketConn{network: "udp6"}
|
||||
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
addr *net.UDPAddr
|
||||
wantSocket string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 address",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
|
||||
wantSocket: "udp4",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
|
||||
wantSocket: "udp6",
|
||||
},
|
||||
{
|
||||
name: "IPv4-mapped IPv6 goes to IPv4",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
|
||||
wantSocket: "udp4",
|
||||
},
|
||||
{
|
||||
name: "IPv4 loopback",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
|
||||
wantSocket: "udp4",
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
|
||||
wantSocket: "udp6",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ipv4Conn.writeCount = 0
|
||||
ipv6Conn.writeCount = 0
|
||||
|
||||
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 4, n)
|
||||
|
||||
if tt.wantSocket == "udp4" {
|
||||
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
|
||||
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
|
||||
} else {
|
||||
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
|
||||
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
|
||||
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
|
||||
|
||||
// IPv4 works
|
||||
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
|
||||
require.NoError(t, err)
|
||||
|
||||
// IPv6 fails
|
||||
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no IPv6 connection")
|
||||
}
|
||||
|
||||
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
|
||||
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
|
||||
|
||||
// IPv6 works
|
||||
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
|
||||
require.NoError(t, err)
|
||||
|
||||
// IPv4 fails
|
||||
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no IPv4 connection")
|
||||
}
|
||||
|
||||
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
|
||||
// only reads from one socket (IPv4 preferred). This is fine because the actual
|
||||
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
|
||||
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
|
||||
ipv4Conn := &mockPacketConn{
|
||||
network: "udp4",
|
||||
readData: []byte("from ipv4"),
|
||||
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
|
||||
}
|
||||
ipv6Conn := &mockPacketConn{
|
||||
network: "udp6",
|
||||
readData: []byte("from ipv6"),
|
||||
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
|
||||
}
|
||||
|
||||
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
|
||||
|
||||
buf := make([]byte, 100)
|
||||
n, addr, err := dualStack.ReadFrom(buf)
|
||||
|
||||
require.NoError(t, err)
|
||||
// reads from IPv4 (preferred) - this is expected behavior
|
||||
assert.Equal(t, "from ipv4", string(buf[:n]))
|
||||
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
|
||||
}
|
||||
|
||||
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
|
||||
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
|
||||
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ipv4 net.PacketConn
|
||||
ipv6 net.PacketConn
|
||||
wantAddr net.Addr
|
||||
}{
|
||||
{
|
||||
name: "both available returns IPv4",
|
||||
ipv4: &mockPacketConn{localAddr: ipv4Addr},
|
||||
ipv6: &mockPacketConn{localAddr: ipv6Addr},
|
||||
wantAddr: ipv4Addr,
|
||||
},
|
||||
{
|
||||
name: "IPv4 only",
|
||||
ipv4: &mockPacketConn{localAddr: ipv4Addr},
|
||||
ipv6: nil,
|
||||
wantAddr: ipv4Addr,
|
||||
},
|
||||
{
|
||||
name: "IPv6 only",
|
||||
ipv4: nil,
|
||||
ipv6: &mockPacketConn{localAddr: ipv6Addr},
|
||||
wantAddr: ipv6Addr,
|
||||
},
|
||||
{
|
||||
name: "neither returns nil",
|
||||
ipv4: nil,
|
||||
ipv6: nil,
|
||||
wantAddr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
|
||||
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mock
|
||||
|
||||
type mockPacketConn struct {
|
||||
network string
|
||||
writeCount int
|
||||
readData []byte
|
||||
readAddr net.Addr
|
||||
localAddr net.Addr
|
||||
}
|
||||
|
||||
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
if m.readData != nil {
|
||||
return copy(b, m.readData), m.readAddr, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
m.writeCount++
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (m *mockPacketConn) Close() error { return nil }
|
||||
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
|
||||
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
@@ -27,8 +26,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
@@ -58,6 +57,8 @@ type ICEBind struct {
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
ipv4Conn *net.UDPConn
|
||||
ipv6Conn *net.UDPConn
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
@@ -103,6 +104,12 @@ func (s *ICEBind) Close() error {
|
||||
|
||||
close(s.closedChan)
|
||||
|
||||
s.muUDPMux.Lock()
|
||||
s.ipv4Conn = nil
|
||||
s.ipv6Conn = nil
|
||||
s.udpMux = nil
|
||||
s.muUDPMux.Unlock()
|
||||
|
||||
return s.StdNetBind.Close()
|
||||
}
|
||||
|
||||
@@ -160,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(conn),
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
MTU: s.mtu,
|
||||
},
|
||||
)
|
||||
// Detect IPv4 vs IPv6 from connection's local address
|
||||
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
|
||||
s.ipv4Conn = conn
|
||||
} else {
|
||||
s.ipv6Conn = conn
|
||||
}
|
||||
s.createOrUpdateMux()
|
||||
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := getMessages(msgsPool)
|
||||
for i := range bufs {
|
||||
@@ -180,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer putMessages(msgs, msgsPool)
|
||||
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||
//nolint
|
||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
//nolint:staticcheck
|
||||
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -207,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
|
||||
continue
|
||||
}
|
||||
sizes[i] = msg.N
|
||||
@@ -233,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
}
|
||||
}
|
||||
|
||||
// createOrUpdateMux creates or updates the UDP mux with the available connections.
|
||||
// Must be called with muUDPMux held.
|
||||
func (s *ICEBind) createOrUpdateMux() {
|
||||
var muxConn net.PacketConn
|
||||
|
||||
switch {
|
||||
case s.ipv4Conn != nil && s.ipv6Conn != nil:
|
||||
muxConn = NewDualStackPacketConn(
|
||||
nbnet.WrapPacketConn(s.ipv4Conn),
|
||||
nbnet.WrapPacketConn(s.ipv6Conn),
|
||||
)
|
||||
case s.ipv4Conn != nil:
|
||||
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
|
||||
case s.ipv6Conn != nil:
|
||||
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
// Don't close the old mux - it doesn't own the underlying connections.
|
||||
// The sockets are managed by WireGuard's StdNetBind, not by us.
|
||||
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: muxConn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
MTU: s.mtu,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||
for i := range buffers {
|
||||
if !stun.IsMessage(buffers[i]) {
|
||||
@@ -245,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
|
||||
return true, err
|
||||
}
|
||||
|
||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||
if muxErr != nil {
|
||||
log.Warnf("failed to handle STUN packet")
|
||||
s.muUDPMux.Lock()
|
||||
mux := s.udpMux
|
||||
s.muUDPMux.Unlock()
|
||||
|
||||
if mux != nil {
|
||||
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
|
||||
log.Warnf("failed to handle STUN packet: %v", muxErr)
|
||||
}
|
||||
}
|
||||
|
||||
buffers[i] = []byte{}
|
||||
|
||||
324
client/iface/bind/ice_bind_test.go
Normal file
324
client/iface/bind/ice_bind_test.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
|
||||
iceBind := setupICEBind(t)
|
||||
|
||||
ipv4Conn, ipv6Conn := createDualStackConns(t)
|
||||
defer ipv4Conn.Close()
|
||||
defer ipv6Conn.Close()
|
||||
|
||||
rc := receiverCreator{iceBind}
|
||||
pool := createMsgPool()
|
||||
|
||||
// Simulate wireguard-go calling CreateReceiverFn for IPv4
|
||||
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
|
||||
require.NotNil(t, ipv4RecvFn)
|
||||
|
||||
iceBind.muUDPMux.Lock()
|
||||
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
|
||||
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
|
||||
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
|
||||
iceBind.muUDPMux.Unlock()
|
||||
|
||||
// Simulate wireguard-go calling CreateReceiverFn for IPv6
|
||||
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
|
||||
require.NotNil(t, ipv6RecvFn)
|
||||
|
||||
iceBind.muUDPMux.Lock()
|
||||
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
|
||||
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
|
||||
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
|
||||
iceBind.muUDPMux.Unlock()
|
||||
|
||||
mux, err := iceBind.GetICEMux()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mux)
|
||||
}
|
||||
|
||||
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
|
||||
iceBind := setupICEBind(t)
|
||||
|
||||
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
require.NoError(t, err)
|
||||
defer ipv4Conn.Close()
|
||||
|
||||
rc := receiverCreator{iceBind}
|
||||
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
|
||||
require.NotNil(t, recvFn)
|
||||
|
||||
iceBind.muUDPMux.Lock()
|
||||
assert.NotNil(t, iceBind.ipv4Conn)
|
||||
assert.Nil(t, iceBind.ipv6Conn)
|
||||
assert.NotNil(t, iceBind.udpMux)
|
||||
iceBind.muUDPMux.Unlock()
|
||||
|
||||
mux, err := iceBind.GetICEMux()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mux)
|
||||
}
|
||||
|
||||
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
|
||||
iceBind := setupICEBind(t)
|
||||
|
||||
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer ipv6Conn.Close()
|
||||
|
||||
rc := receiverCreator{iceBind}
|
||||
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
|
||||
require.NotNil(t, recvFn)
|
||||
|
||||
iceBind.muUDPMux.Lock()
|
||||
assert.Nil(t, iceBind.ipv4Conn)
|
||||
assert.NotNil(t, iceBind.ipv6Conn)
|
||||
assert.NotNil(t, iceBind.udpMux)
|
||||
iceBind.muUDPMux.Unlock()
|
||||
|
||||
mux, err := iceBind.GetICEMux()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mux)
|
||||
}
|
||||
|
||||
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
|
||||
// with peers on different address families through the same DualStackPacketConn.
|
||||
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
|
||||
// two "remote peers" listening on different address families
|
||||
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||
defer ipv4Peer.Close()
|
||||
|
||||
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
|
||||
if err != nil {
|
||||
t.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer ipv6Peer.Close()
|
||||
|
||||
// our local dual-stack connection
|
||||
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||
defer ipv4Local.Close()
|
||||
|
||||
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
|
||||
defer ipv6Local.Close()
|
||||
|
||||
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
|
||||
|
||||
// send to both peers
|
||||
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify IPv4 peer got its packet from the IPv4 socket
|
||||
buf := make([]byte, 100)
|
||||
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
|
||||
n, addr, err := ipv4Peer.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "to-ipv4", string(buf[:n]))
|
||||
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
|
||||
|
||||
// verify IPv6 peer got its packet from the IPv6 socket
|
||||
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
|
||||
n, addr, err = ipv6Peer.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "to-ipv6", string(buf[:n]))
|
||||
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
|
||||
}
|
||||
|
||||
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
|
||||
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
|
||||
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
|
||||
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||
defer ipv4Peer.Close()
|
||||
|
||||
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
|
||||
if err != nil {
|
||||
t.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
defer ipv6Peer.Close()
|
||||
|
||||
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||
defer ipv4Local.Close()
|
||||
|
||||
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
|
||||
defer ipv6Local.Close()
|
||||
|
||||
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
|
||||
|
||||
const packetsPerFamily = 500
|
||||
|
||||
ipv4Received := make(chan string, packetsPerFamily)
|
||||
ipv6Received := make(chan string, packetsPerFamily)
|
||||
|
||||
startGate := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 100)
|
||||
for i := 0; i < packetsPerFamily; i++ {
|
||||
n, _, err := ipv4Peer.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ipv4Received <- string(buf[:n])
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 100)
|
||||
for i := 0; i < packetsPerFamily; i++ {
|
||||
n, _, err := ipv6Peer.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ipv6Received <- string(buf[:n])
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startGate
|
||||
for i := 0; i < packetsPerFamily; i++ {
|
||||
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startGate
|
||||
for i := 0; i < packetsPerFamily; i++ {
|
||||
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
|
||||
}
|
||||
}()
|
||||
|
||||
close(startGate)
|
||||
|
||||
time.AfterFunc(5*time.Second, func() {
|
||||
_ = ipv4Peer.SetReadDeadline(time.Now())
|
||||
_ = ipv6Peer.SetReadDeadline(time.Now())
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
close(ipv4Received)
|
||||
close(ipv6Received)
|
||||
|
||||
ipv4Count := 0
|
||||
for pkt := range ipv4Received {
|
||||
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
|
||||
ipv4Count++
|
||||
}
|
||||
|
||||
ipv6Count := 0
|
||||
for pkt := range ipv6Received {
|
||||
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
|
||||
ipv6Count++
|
||||
}
|
||||
|
||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||
}
|
||||
|
||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
addr string
|
||||
wantIPv4 bool
|
||||
}{
|
||||
{"IPv4 any", "udp4", "0.0.0.0:0", true},
|
||||
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
|
||||
{"IPv6 any", "udp6", "[::]:0", false},
|
||||
{"IPv6 loopback", "udp6", "[::1]:0", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.ListenUDP(tt.network, addr)
|
||||
if err != nil {
|
||||
t.Skipf("%s not available: %v", tt.network, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
isIPv4 := localAddr.IP.To4() != nil
|
||||
assert.Equal(t, tt.wantIPv4, isIPv4)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// helpers
|
||||
|
||||
func setupICEBind(t *testing.T) *ICEBind {
|
||||
t.Helper()
|
||||
transportNet, err := stdnet.NewNet()
|
||||
require.NoError(t, err)
|
||||
|
||||
address := wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/10"),
|
||||
}
|
||||
return NewICEBind(transportNet, nil, address, 1280)
|
||||
}
|
||||
|
||||
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
|
||||
t.Helper()
|
||||
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
require.NoError(t, err)
|
||||
|
||||
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||
if err != nil {
|
||||
ipv4Conn.Close()
|
||||
t.Skipf("IPv6 not available: %v", err)
|
||||
}
|
||||
return ipv4Conn, ipv6Conn
|
||||
}
|
||||
|
||||
func createMsgPool() *sync.Pool {
|
||||
return &sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv6.Message, 1)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, 0, 40)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
|
||||
t.Helper()
|
||||
udpAddr, err := net.ResolveUDPAddr(network, addr)
|
||||
require.NoError(t, err)
|
||||
conn, err := net.ListenUDP(network, udpAddr)
|
||||
require.NoError(t, err)
|
||||
return conn
|
||||
}
|
||||
@@ -3,8 +3,22 @@ package configurer
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer.
|
||||
// This is a shared helper used by both kernel and userspace configurers.
|
||||
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
|
||||
return wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{{
|
||||
PublicKey: peerKey,
|
||||
PresharedKey: &psk,
|
||||
UpdateOnly: updateOnly,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
||||
ipNets := make([]net.IPNet, len(prefixes))
|
||||
for i, prefix := range prefixes {
|
||||
|
||||
@@ -15,8 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
var zeroKey wgtypes.Key
|
||||
|
||||
type KernelConfigurer struct {
|
||||
deviceName string
|
||||
}
|
||||
@@ -48,6 +46,18 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetPresharedKey sets the preshared key for a peer.
|
||||
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
|
||||
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
|
||||
return c.configure(cfg)
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
@@ -279,7 +289,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||
TxBytes: p.TransmitBytes,
|
||||
RxBytes: p.ReceiveBytes,
|
||||
LastHandshake: p.LastHandshakeTime,
|
||||
PresharedKey: p.PresharedKey != zeroKey,
|
||||
PresharedKey: [32]byte(p.PresharedKey),
|
||||
}
|
||||
if p.Endpoint != nil {
|
||||
peer.Endpoint = *p.Endpoint
|
||||
|
||||
@@ -22,17 +22,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
privateKey = "private_key"
|
||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||
ipcKeyTxBytes = "tx_bytes"
|
||||
ipcKeyRxBytes = "rx_bytes"
|
||||
allowedIP = "allowed_ip"
|
||||
endpoint = "endpoint"
|
||||
fwmark = "fwmark"
|
||||
listenPort = "listen_port"
|
||||
publicKey = "public_key"
|
||||
presharedKey = "preshared_key"
|
||||
privateKey = "private_key"
|
||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||
ipcKeyTxBytes = "tx_bytes"
|
||||
ipcKeyRxBytes = "rx_bytes"
|
||||
allowedIP = "allowed_ip"
|
||||
endpoint = "endpoint"
|
||||
fwmark = "fwmark"
|
||||
listenPort = "listen_port"
|
||||
publicKey = "public_key"
|
||||
presharedKey = "preshared_key"
|
||||
)
|
||||
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
@@ -72,6 +71,18 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
// SetPresharedKey sets the preshared key for a peer.
|
||||
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
|
||||
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
|
||||
return c.device.IpcSet(toWgUserspaceString(cfg))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
@@ -422,23 +433,19 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
hexKey := hex.EncodeToString(p.PublicKey[:])
|
||||
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
|
||||
|
||||
if p.Remove {
|
||||
sb.WriteString("remove=true\n")
|
||||
}
|
||||
|
||||
if p.UpdateOnly {
|
||||
sb.WriteString("update_only=true\n")
|
||||
}
|
||||
|
||||
if p.PresharedKey != nil {
|
||||
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
||||
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
|
||||
}
|
||||
|
||||
if p.Remove {
|
||||
sb.WriteString("remove=true")
|
||||
}
|
||||
|
||||
if p.ReplaceAllowedIPs {
|
||||
sb.WriteString("replace_allowed_ips=true\n")
|
||||
}
|
||||
|
||||
for _, aip := range p.AllowedIPs {
|
||||
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
|
||||
}
|
||||
|
||||
if p.Endpoint != nil {
|
||||
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
|
||||
}
|
||||
@@ -446,6 +453,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
if p.PersistentKeepaliveInterval != nil {
|
||||
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
|
||||
}
|
||||
|
||||
if p.ReplaceAllowedIPs {
|
||||
sb.WriteString("replace_allowed_ips=true\n")
|
||||
}
|
||||
|
||||
for _, aip := range p.AllowedIPs {
|
||||
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -543,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||
host, portStr, err := net.SplitHostPort(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint: %v", err)
|
||||
continue
|
||||
@@ -599,7 +614,9 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
continue
|
||||
}
|
||||
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||
currentPeer.PresharedKey = true
|
||||
if pskKey, err := hexToWireguardKey(val); err == nil {
|
||||
currentPeer.PresharedKey = [32]byte(pskKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ type Peer struct {
|
||||
TxBytes int64
|
||||
RxBytes int64
|
||||
LastHandshake time.Time
|
||||
PresharedKey bool
|
||||
PresharedKey [32]byte
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
|
||||
@@ -29,8 +29,9 @@ type PacketFilter interface {
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
|
||||
filter PacketFilter
|
||||
mutex sync.RWMutex
|
||||
filter PacketFilter
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// newDeviceFilter constructor function
|
||||
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying tun device exactly once.
|
||||
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
||||
// and multiple code paths can trigger Close on the same device.
|
||||
func (d *FilteredDevice) Close() error {
|
||||
var err error
|
||||
d.closeOnce.Do(func() {
|
||||
err = d.Device.Close()
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read wraps read method with filtering feature
|
||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
//go:build ios
|
||||
// +build ios
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -45,10 +43,31 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
|
||||
// This typically means the Swift code couldn't find the utun control socket.
|
||||
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
|
||||
|
||||
func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
log.Infof("create tun interface")
|
||||
|
||||
dupTunFd, err := unix.Dup(t.tunFd)
|
||||
var tunDevice tun.Device
|
||||
var err error
|
||||
|
||||
// Validate the tunnel file descriptor.
|
||||
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
|
||||
// A value of 0 means the Swift code couldn't find the utun control socket
|
||||
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
|
||||
// tvOS SDK headers). This is a hard error - there's no viable fallback
|
||||
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
|
||||
if t.tunFd == 0 {
|
||||
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
|
||||
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
|
||||
return nil, ErrInvalidTunnelFD
|
||||
}
|
||||
|
||||
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
|
||||
var dupTunFd int
|
||||
dupTunFd, err = unix.Dup(t.tunFd)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to dup tun fd: %v", err)
|
||||
return nil, err
|
||||
@@ -60,7 +79,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
_ = unix.Close(dupTunFd)
|
||||
return nil, err
|
||||
}
|
||||
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
|
||||
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to create new tun device from fd: %v", err)
|
||||
_ = unix.Close(dupTunFd)
|
||||
|
||||
@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
if cErr := tunIface.Close(); cErr != nil {
|
||||
log.Debugf("failed to close tun device: %v", cErr)
|
||||
}
|
||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ type WGConfigurer interface {
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||
Close()
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
@@ -50,6 +51,7 @@ func ValidateMTU(mtu uint16) error {
|
||||
|
||||
type wgProxyFactory interface {
|
||||
GetProxy() wgproxy.Proxy
|
||||
GetProxyPort() uint16
|
||||
Free() error
|
||||
}
|
||||
|
||||
@@ -80,6 +82,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy port used by the WireGuard proxy.
|
||||
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
|
||||
func (w *WGIface) GetProxyPort() uint16 {
|
||||
return w.wgProxyFactory.GetProxyPort()
|
||||
}
|
||||
|
||||
// GetBind returns the EndpointManager userspace bind mode.
|
||||
func (w *WGIface) GetBind() device.EndpointManager {
|
||||
w.mu.Lock()
|
||||
@@ -221,6 +229,10 @@ func (w *WGIface) Close() error {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||
}
|
||||
|
||||
if nbnetstack.IsEnabled() {
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
if err := w.waitUntilRemoved(); err != nil {
|
||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||
if err := w.Destroy(); err != nil {
|
||||
@@ -297,6 +309,19 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||
return w.configurer.FullStats()
|
||||
}
|
||||
|
||||
// SetPresharedKey sets or updates the preshared key for a peer.
|
||||
// If updateOnly is true, only updates existing peer; if false, creates or updates.
|
||||
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
return w.configurer.SetPresharedKey(peerKey, psk, updateOnly)
|
||||
}
|
||||
|
||||
func (w *WGIface) waitUntilRemoved() error {
|
||||
maxWaitTime := 5 * time.Second
|
||||
timeout := time.NewTimer(maxWaitTime)
|
||||
|
||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return nsTunDev, tunNet, nil
|
||||
return t.tundev, tunNet, nil
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Close() error {
|
||||
|
||||
@@ -114,21 +114,21 @@ func (p *ProxyBind) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
ep, err := addrToEndpoint(endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to start package redirection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
||||
p.wgCurrentUsed = ep
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
@@ -212,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||
return &netipAddr, nil
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
|
||||
if addr == nil {
|
||||
return nil, fmt.Errorf("invalid address")
|
||||
}
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -26,13 +24,10 @@ const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
var (
|
||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
proxyPort int
|
||||
mtu uint16
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
@@ -40,7 +35,8 @@ type WGEBPFProxy struct {
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
lastUsedPort uint16
|
||||
rawConn net.PacketConn
|
||||
rawConnIPv4 net.PacketConn
|
||||
rawConnIPv6 net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
@@ -62,23 +58,39 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
proxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.proxyPort = proxyPort
|
||||
|
||||
// Prepare IPv4 raw socket (required)
|
||||
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||
// Prepare IPv6 raw socket (optional)
|
||||
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
if err != nil {
|
||||
return err
|
||||
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
|
||||
}
|
||||
|
||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
||||
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
|
||||
}
|
||||
if p.rawConnIPv6 != nil {
|
||||
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: wgPorxyPort,
|
||||
Port: proxyPort,
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
}
|
||||
|
||||
@@ -94,7 +106,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
p.conn = conn
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
log.Infof("local wg proxy listening on: %d", proxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -135,12 +147,25 @@ func (p *WGEBPFProxy) Free() error {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
if p.rawConnIPv4 != nil {
|
||||
if err := p.rawConnIPv4.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
if p.rawConnIPv6 != nil {
|
||||
if err := p.rawConnIPv6.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the proxy listening port.
|
||||
func (p *WGEBPFProxy) GetProxyPort() uint16 {
|
||||
return uint16(p.proxyPort)
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
@@ -216,34 +241,3 @@ generatePort:
|
||||
}
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localHostNetIP,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
layerBuffer := gopacket.NewSerializeBuffer()
|
||||
|
||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,12 +10,89 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
var (
|
||||
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
|
||||
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
|
||||
|
||||
localHostNetIPv4 = net.ParseIP("127.0.0.1")
|
||||
localHostNetIPv6 = net.ParseIP("::1")
|
||||
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
)
|
||||
|
||||
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
|
||||
type PacketHeaders struct {
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH *layers.UDP
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr net.IP
|
||||
isIPv4 bool
|
||||
}
|
||||
|
||||
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
var localHostAddr net.IP
|
||||
var isIPv4 bool
|
||||
|
||||
// Check if source address is IPv4 or IPv6
|
||||
if endpoint.IP.To4() != nil {
|
||||
// IPv4 path
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPv4,
|
||||
SrcIP: endpoint.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
localHostAddr = localHostNetIPv4
|
||||
isIPv4 = true
|
||||
} else {
|
||||
// IPv6 path
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPv6,
|
||||
SrcIP: endpoint.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
localHostAddr = localHostNetIPv6
|
||||
isIPv4 = false
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(endpoint.Port),
|
||||
DstPort: layers.UDPPort(localWGListenPort),
|
||||
}
|
||||
|
||||
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
return &PacketHeaders{
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
isIPv4: isIPv4,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
wgeBPFProxy *WGEBPFProxy
|
||||
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
headers *PacketHeaders
|
||||
headerCurrentUsed *PacketHeaders
|
||||
rawConn net.PacketConn
|
||||
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
|
||||
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create packet sender: %w", err)
|
||||
}
|
||||
|
||||
// Check if required raw connection is available
|
||||
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||
return errIPv6ConnNotAvailable
|
||||
}
|
||||
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||
return errIPv4ConnNotAvailable
|
||||
}
|
||||
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgRelayedEndpointAddr = addr
|
||||
return err
|
||||
p.headers = headers
|
||||
p.rawConn = p.selectRawConn(headers)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||
p.headerCurrentUsed = p.headers
|
||||
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
@@ -91,10 +188,32 @@ func (p *ProxyWrapper) Pause() {
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
if endpoint == nil || endpoint.IP == nil {
|
||||
log.Errorf("failed to start package redirection, endpoint is nil")
|
||||
return
|
||||
}
|
||||
|
||||
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create packet headers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if required raw connection is available
|
||||
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
|
||||
log.Error(errIPv6ConnNotAvailable)
|
||||
return
|
||||
}
|
||||
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
|
||||
log.Error(errIPv4ConnNotAvailable)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
p.headerCurrentUsed = header
|
||||
p.rawConn = p.selectRawConn(header)
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
@@ -136,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
@@ -162,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
|
||||
defer func() {
|
||||
if err := header.layerBuffer.Clear(); err != nil {
|
||||
log.Errorf("failed to clear layer buffer: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
|
||||
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
|
||||
if header.isIPv4 {
|
||||
return p.wgeBPFProxy.rawConnIPv4
|
||||
}
|
||||
return p.wgeBPFProxy.rawConnIPv6
|
||||
}
|
||||
|
||||
@@ -3,12 +3,19 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
envDisableEBPFWGProxy = "NB_DISABLE_EBPF_WG_PROXY"
|
||||
)
|
||||
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
mtu uint16
|
||||
@@ -22,6 +29,12 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
if isEBPFDisabled() {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
log.Infof("eBPF WireGuard proxy is disabled via %s environment variable", envDisableEBPFWGProxy)
|
||||
return f
|
||||
}
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
@@ -41,9 +54,30 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
}
|
||||
|
||||
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
|
||||
func (w *KernelFactory) GetProxyPort() uint16 {
|
||||
if w.ebpfProxy == nil {
|
||||
return 0
|
||||
}
|
||||
return w.ebpfProxy.GetProxyPort()
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
}
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
|
||||
func isEBPFDisabled() bool {
|
||||
val := os.Getenv(envDisableEBPFWGProxy)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envDisableEBPFWGProxy, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
@@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy {
|
||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||
}
|
||||
|
||||
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
|
||||
func (w *USPFactory) GetProxyPort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,43 +8,87 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
|
||||
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET, true)
|
||||
}
|
||||
|
||||
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
|
||||
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
|
||||
return prepareSenderRawSocket(syscall.AF_INET6, false)
|
||||
}
|
||||
|
||||
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
|
||||
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
|
||||
if isIPv4 {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
if closeErr := syscall.Close(fd); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket fd: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
// Close the original file to release the FD (net.FilePacketConn duplicates it)
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
353
client/iface/wgproxy/redirect_test.go
Normal file
353
client/iface/wgproxy/redirect_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
|
||||
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
|
||||
func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
udpAddr1, ok1 := addr1.(*net.UDPAddr)
|
||||
udpAddr2, ok2 := addr2.(*net.UDPAddr)
|
||||
|
||||
if !ok1 || !ok2 {
|
||||
return addr1.String() == addr2.String()
|
||||
}
|
||||
|
||||
// Compare IP and Port, ignoring zone
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
|
||||
func TestRedirectAs_UDP_IPv6(t *testing.T) {
|
||||
wgPort := 51853
|
||||
proxy := udp.NewWGUDPProxy(wgPort, 1280)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// testRedirectAs is a helper function that tests the RedirectAs functionality
|
||||
// It verifies that:
|
||||
// 1. Initial traffic from relay connection works
|
||||
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
|
||||
// 3. Multiple packets are correctly redirected with the new source address
|
||||
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
|
||||
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
|
||||
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener4.Close()
|
||||
|
||||
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener6.Close()
|
||||
|
||||
// Determine which listener to use based on the NetBird address IP version
|
||||
// (this is where initial traffic will come from before RedirectAs is called)
|
||||
var wgListener *net.UDPConn
|
||||
if p2pEndpoint.IP.To4() == nil {
|
||||
wgListener = wgListener6
|
||||
} else {
|
||||
wgListener = wgListener4
|
||||
}
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0, // Random port
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
// Add TURN connection to proxy
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start the proxy
|
||||
proxy.Work()
|
||||
|
||||
// Phase 1: Test initial relay traffic
|
||||
msgFromRelay := []byte("hello from relay")
|
||||
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write to relay server: %v", err)
|
||||
}
|
||||
|
||||
// Set read deadline to avoid hanging
|
||||
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, _, err := wgListener4.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from WireGuard listener: %v", err)
|
||||
}
|
||||
|
||||
if n != len(msgFromRelay) {
|
||||
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msgFromRelay) {
|
||||
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
|
||||
}
|
||||
|
||||
// Phase 2: Redirect to p2p endpoint
|
||||
proxy.RedirectAs(p2pEndpoint)
|
||||
|
||||
// Give the proxy a moment to process the redirect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Phase 3: Test redirected traffic
|
||||
redirectedMessages := [][]byte{
|
||||
[]byte("redirected message 1"),
|
||||
[]byte("redirected message 2"),
|
||||
[]byte("redirected message 3"),
|
||||
}
|
||||
|
||||
for i, msg := range redirectedMessages {
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
|
||||
}
|
||||
|
||||
// Verify message content
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
|
||||
}
|
||||
|
||||
// Verify source address matches p2p endpoint (this is the key test)
|
||||
// Use compareUDPAddr to ignore IPv6 zone IDs
|
||||
if !compareUDPAddr(srcAddr, p2pEndpoint) {
|
||||
t.Errorf("message %d: expected source address %s, got %s",
|
||||
i+1, p2pEndpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create WireGuard listener
|
||||
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create WireGuard listener: %v", err)
|
||||
}
|
||||
defer wgListener.Close()
|
||||
|
||||
// Create relay server and connection
|
||||
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay server: %v", err)
|
||||
}
|
||||
defer relayServer.Close()
|
||||
|
||||
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create relay connection: %v", err)
|
||||
}
|
||||
defer relayConn.Close()
|
||||
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
|
||||
t.Fatalf("failed to add TURN connection: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("failed to close proxy connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.Work()
|
||||
|
||||
// Test switching between multiple endpoints - using addresses in local subnet
|
||||
endpoints := []*net.UDPAddr{
|
||||
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
|
||||
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
|
||||
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
|
||||
}
|
||||
|
||||
for i, endpoint := range endpoints {
|
||||
proxy.RedirectAs(endpoint)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
msg := []byte("test message")
|
||||
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
|
||||
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("failed to set read deadline: %v", err)
|
||||
}
|
||||
|
||||
n, srcAddr, err := wgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
|
||||
}
|
||||
|
||||
if !compareUDPAddr(srcAddr, endpoint) {
|
||||
t.Errorf("endpoint %d: expected source %s, got %s",
|
||||
i, endpoint.String(), srcAddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
|
||||
@@ -19,37 +19,56 @@ var (
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
localHostNetIPAddrV4 = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
localHostNetIPAddrV6 = &net.IPAddr{
|
||||
IP: net.ParseIP("::1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
localHostAddr *net.IPAddr
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||
// Create only the raw socket for the address family we need
|
||||
var rawSocket net.PacketConn
|
||||
var err error
|
||||
var localHostAddr *net.IPAddr
|
||||
|
||||
if srcAddr.IP.To4() != nil {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
|
||||
localHostAddr = localHostNetIPAddrV4
|
||||
} else {
|
||||
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
|
||||
localHostAddr = localHostNetIPAddrV6
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
if closeErr := rawSocket.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close raw socket: %v", closeErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
localHostAddr: localHostAddr,
|
||||
}
|
||||
|
||||
return f, nil
|
||||
@@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
@@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
var ipH gopacket.SerializableLayer
|
||||
var networkLayer gopacket.NetworkLayer
|
||||
|
||||
// Check if source IP is IPv4 or IPv6
|
||||
if srcAddr.IP.To4() != nil {
|
||||
// IPv4
|
||||
ipv4 := &layers.IPv4{
|
||||
DstIP: localHostNetIPAddrV4.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv4
|
||||
networkLayer = ipv4
|
||||
} else {
|
||||
// IPv6
|
||||
ipv6 := &layers.IPv6{
|
||||
DstIP: localHostNetIPAddrV6.IP,
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
}
|
||||
ipH = ipv6
|
||||
networkLayer = ipv6
|
||||
}
|
||||
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
err := udpH.SetNetworkLayerForChecksum(networkLayer)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||
// This tests the full ACL manager -> uspfilter integration.
|
||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||
for i := 0; i < 5; i++ {
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
}
|
||||
|
||||
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||
}
|
||||
|
||||
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||
// up when they're removed from the network map in a subsequent update.
|
||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First update: add deny and accept rules
|
||||
networkMap1 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap1, false)
|
||||
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||
|
||||
// Second update: remove the deny rule, keep only accept
|
||||
networkMap2 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap2, false)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"Should have 1 rule after removing deny rule")
|
||||
|
||||
// Third update: remove all rules
|
||||
networkMap3 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap3, false)
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||
"Should have 0 rules after removing all rules")
|
||||
}
|
||||
|
||||
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||
// one added without leaking.
|
||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, fw.Close(nil))
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First update: accept rule
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||
|
||||
// Second update: change to deny (same IP/port/proto, different action)
|
||||
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
}
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"Changing action should result in exactly 1 rule, not 2")
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
499
client/internal/auth/auth.go
Normal file
499
client/internal/auth/auth.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Auth manages authentication operations with the management server
|
||||
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||
type Auth struct {
|
||||
mutex sync.RWMutex
|
||||
client *mgm.GrpcClient
|
||||
config *profilemanager.Config
|
||||
privateKey wgtypes.Key
|
||||
mgmURL *url.URL
|
||||
mgmTLSEnabled bool
|
||||
}
|
||||
|
||||
// NewAuth creates a new Auth instance that manages authentication flows
|
||||
// It establishes a connection to the management server that will be reused for all operations
|
||||
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||
// Validate WireGuard private key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine TLS setting based on URL scheme
|
||||
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
return &Auth{
|
||||
client: mgmClient,
|
||||
config: config,
|
||||
privateKey: myPrivateKey,
|
||||
mgmURL: mgmURL,
|
||||
mgmTLSEnabled: mgmTLSEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the management client connection
|
||||
func (a *Auth) Close() error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
if a.client == nil {
|
||||
return nil
|
||||
}
|
||||
return a.client.Close()
|
||||
}
|
||||
|
||||
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||
var supportsSSO bool
|
||||
|
||||
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
// Try PKCE flow first
|
||||
_, err := a.getPKCEFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if PKCE is not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// PKCE not supported, try Device flow
|
||||
_, err = a.getDeviceFlow(client)
|
||||
if err == nil {
|
||||
supportsSSO = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if Device flow is also not supported
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
// Neither PKCE nor Device flow is supported
|
||||
supportsSSO = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Device flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
}
|
||||
|
||||
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||
return err
|
||||
})
|
||||
|
||||
return supportsSSO, err
|
||||
}
|
||||
|
||||
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||
// This avoids creating a new connection to the management server
|
||||
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||
var flow OAuthFlow
|
||||
var err error
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
if forceDeviceAuth {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
|
||||
// Try PKCE flow first
|
||||
flow, err = a.getPKCEFlow(client)
|
||||
if err != nil {
|
||||
// If PKCE not supported, try Device flow
|
||||
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
flow, err = a.getDeviceFlow(client)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return flow, err
|
||||
}
|
||||
|
||||
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
}
|
||||
needsLogin = false
|
||||
return err
|
||||
})
|
||||
|
||||
return needsLogin, err
|
||||
}
|
||||
|
||||
// Login attempts to log in or register the client with the management server
|
||||
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||
// Automatically retries with backoff and reconnection on connection errors.
|
||||
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
isAuthError = isPermissionDenied(err)
|
||||
return err
|
||||
}
|
||||
|
||||
isAuthError = false
|
||||
return nil
|
||||
})
|
||||
|
||||
return err, isAuthError
|
||||
}
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
ClientCertPair: a.config.ClientCertKeyPair,
|
||||
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||
}
|
||||
|
||||
if err := validatePKCEConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConfig := protoFlow.GetProviderConfig()
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
UseIDToken: protoConfig.GetUseIDToken(),
|
||||
}
|
||||
|
||||
// Keep compatibility with older management versions
|
||||
if config.Scope == "" {
|
||||
config.Scope = "openid"
|
||||
}
|
||||
|
||||
if err := validateDeviceAuthConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flow, nil
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("peer has been successfully registered on Management Service")
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
a.config.DisableFirewall,
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
a.config.EnableSSHRemotePortForwarding,
|
||||
a.config.DisableSSHAuth,
|
||||
)
|
||||
}
|
||||
|
||||
// reconnect closes the current connection and creates a new one
|
||||
// It checks if the brokenClient is still the current client before reconnecting
|
||||
// to avoid multiple threads reconnecting unnecessarily
|
||||
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||
if a.client != brokenClient {
|
||||
log.Debugf("client already reconnected by another thread, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new connection FIRST, before closing the old one
|
||||
// This ensures a.client is never nil, preventing panics in other threads
|
||||
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||
// Keep the old client if reconnection fails
|
||||
return err
|
||||
}
|
||||
|
||||
// Close old connection AFTER new one is successfully created
|
||||
oldClient := a.client
|
||||
a.client = mgmClient
|
||||
|
||||
if oldClient != nil {
|
||||
if err := oldClient.Close(); err != nil {
|
||||
log.Debugf("error closing old connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// These error codes indicate connection issues
|
||||
return s.Code() == codes.Unavailable ||
|
||||
s.Code() == codes.DeadlineExceeded ||
|
||||
s.Code() == codes.Canceled ||
|
||||
s.Code() == codes.Internal
|
||||
}
|
||||
|
||||
// withRetry wraps an operation with exponential backoff retry logic
|
||||
// It automatically reconnects on connection errors
|
||||
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||
backoffSettings := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.5,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 2 * time.Minute,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
backoffSettings.Reset()
|
||||
|
||||
return backoff.RetryNotify(
|
||||
func() error {
|
||||
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||
a.mutex.RLock()
|
||||
currentClient := a.client
|
||||
a.mutex.RUnlock()
|
||||
|
||||
if currentClient == nil {
|
||||
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||
}
|
||||
|
||||
// Execute operation with the captured client
|
||||
err := operation(currentClient)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||
if isConnectionError(err) {
|
||||
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||
|
||||
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||
return reconnectErr
|
||||
}
|
||||
// Return the original error to trigger retry with the new connection
|
||||
return err
|
||||
}
|
||||
|
||||
// For authentication errors, don't retry
|
||||
if isAuthenticationError(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
backoff.WithContext(backoffSettings, ctx),
|
||||
func(err error, duration time.Duration) {
|
||||
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||
func isAuthenticationError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||
func isPermissionDenied(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return s.Code() == codes.PermissionDenied
|
||||
}
|
||||
|
||||
func isLoginNeeded(err error) bool {
|
||||
return isAuthenticationError(err)
|
||||
}
|
||||
|
||||
func isRegistrationNeeded(err error) bool {
|
||||
return isPermissionDenied(err)
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
@@ -26,12 +25,56 @@ const (
|
||||
|
||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||
// for the Device Authorization Flow.
|
||||
type DeviceAuthorizationFlow struct {
|
||||
providerConfig internal.DeviceAuthProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
providerConfig DeviceAuthProviderConfig
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
|
||||
}
|
||||
|
||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
||||
return d.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the device authorization flow
|
||||
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||
d.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
form := url.Values{}
|
||||
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
interval := time.Duration(info.Interval) * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case <-ticker.C:
|
||||
|
||||
tokenResponse, err := d.requestToken(info)
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
err: testCase.inputReqError,
|
||||
}
|
||||
|
||||
deviceFlow := &DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
countResBody: testCase.inputCountResBody,
|
||||
}
|
||||
|
||||
deviceFlow := DeviceAuthorizationFlow{
|
||||
providerConfig: internal.DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
config := DeviceAuthProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
}
|
||||
|
||||
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||
require.NoError(t, err, "creating device flow should not fail")
|
||||
deviceFlow.HTTPClient = &httpClient
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
if hint != "" {
|
||||
pkceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
return pkceFlowInfo, nil
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||
}
|
||||
defer authClient.Close()
|
||||
|
||||
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
case ok && s.Code() == codes.NotFound:
|
||||
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
if hint != "" {
|
||||
deviceFlowInfo.SetLoginHint(hint)
|
||||
}
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
return deviceFlowInfo, nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
@@ -35,17 +34,67 @@ const (
|
||||
defaultPKCETimeoutSeconds = 300
|
||||
)
|
||||
|
||||
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||
type PKCEAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// ClientCertPair is used for mTLS authentication to the IDP
|
||||
ClientCertPair *tls.Certificate
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// validatePKCEConfig validates PKCE provider configuration
|
||||
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||
}
|
||||
if config.AuthorizationEndpoint == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||
}
|
||||
if config.RedirectURLs == nil {
|
||||
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||
// the Authorization Code Flow with PKCE.
|
||||
type PKCEAuthorizationFlow struct {
|
||||
providerConfig internal.PKCEAuthProviderConfig
|
||||
providerConfig PKCEAuthProviderConfig
|
||||
state string
|
||||
codeVerifier string
|
||||
oAuthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||
p.providerConfig.LoginHint = hint
|
||||
}
|
||||
|
||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
||||
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||
// Create timeout context based on flow expiration
|
||||
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-waitCtx.Done():
|
||||
return TokenInfo{}, waitCtx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
config := PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
config := PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -59,7 +60,6 @@ func NewConnectClient(
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
doInitalAutoUpdate bool,
|
||||
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
@@ -71,8 +71,8 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
@@ -93,7 +93,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
@@ -111,10 +111,10 @@ func (c *ConnectClient) RunOniOS(
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -245,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
||||
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||
}
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
@@ -284,7 +284,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
|
||||
return syncResponse, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager if the engine is running.
|
||||
func (c *ConnectClient) SetLogLevel(level log.Level) {
|
||||
engine := c.Engine()
|
||||
if engine == nil {
|
||||
return
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager != nil {
|
||||
fwManager.SetLogLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current client status
|
||||
func (c *ConnectClient) Status() StatusType {
|
||||
if c == nil {
|
||||
@@ -459,7 +472,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
|
||||
nm := false
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
@@ -494,7 +507,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
LogPath: logPath,
|
||||
|
||||
ProfileConfig: config,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
|
||||
@@ -28,8 +28,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
@@ -57,6 +59,7 @@ block.prof: Block profiling information.
|
||||
heap.prof: Heap profiling information (snapshot of memory allocations).
|
||||
allocs.prof: Allocations profiling information.
|
||||
threadcreate.prof: Thread creation profiling information.
|
||||
cpu.prof: CPU profiling information.
|
||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||
|
||||
|
||||
@@ -223,10 +226,11 @@ type BundleGenerator struct {
|
||||
internalConfig *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logFile string
|
||||
logPath string
|
||||
cpuProfile []byte
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
|
||||
anonymize bool
|
||||
clientStatus string
|
||||
includeSystemInfo bool
|
||||
logFileCount uint32
|
||||
|
||||
@@ -235,7 +239,6 @@ type BundleGenerator struct {
|
||||
|
||||
type BundleConfig struct {
|
||||
Anonymize bool
|
||||
ClientStatus string
|
||||
IncludeSystemInfo bool
|
||||
LogFileCount uint32
|
||||
}
|
||||
@@ -244,7 +247,9 @@ type GeneratorDependencies struct {
|
||||
InternalConfig *profilemanager.Config
|
||||
StatusRecorder *peer.Status
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogFile string
|
||||
LogPath string
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
@@ -260,10 +265,11 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
internalConfig: deps.InternalConfig,
|
||||
statusRecorder: deps.StatusRecorder,
|
||||
syncResponse: deps.SyncResponse,
|
||||
logFile: deps.LogFile,
|
||||
logPath: deps.LogPath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
clientStatus: cfg.ClientStatus,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
logFileCount: logFileCount,
|
||||
}
|
||||
@@ -309,13 +315,6 @@ func (g *BundleGenerator) createArchive() error {
|
||||
return fmt.Errorf("add status: %w", err)
|
||||
}
|
||||
|
||||
if g.statusRecorder != nil {
|
||||
status := g.statusRecorder.GetFullStatus()
|
||||
seedFromStatus(g.anonymizer, &status)
|
||||
} else {
|
||||
log.Debugf("no status recorder available for seeding")
|
||||
}
|
||||
|
||||
if err := g.addConfig(); err != nil {
|
||||
log.Errorf("failed to add config to debug bundle: %v", err)
|
||||
}
|
||||
@@ -332,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addCPUProfile(); err != nil {
|
||||
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addStackTrace(); err != nil {
|
||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||
}
|
||||
@@ -352,7 +355,7 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add wg show output: %v", err)
|
||||
}
|
||||
|
||||
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
|
||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
@@ -401,11 +404,30 @@ func (g *BundleGenerator) addReadme() error {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStatus() error {
|
||||
if status := g.clientStatus; status != "" {
|
||||
statusReader := strings.NewReader(status)
|
||||
if g.statusRecorder != nil {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
if g.refreshStatus != nil {
|
||||
g.refreshStatus()
|
||||
}
|
||||
|
||||
fullStatus := g.statusRecorder.GetFullStatus()
|
||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
||||
statusOutput := overview.FullDetailSummary()
|
||||
|
||||
statusReader := strings.NewReader(statusOutput)
|
||||
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
|
||||
return fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
seedFromStatus(g.anonymizer, &fullStatus)
|
||||
} else {
|
||||
log.Debugf("no status recorder available for seeding")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -535,6 +557,19 @@ func (g *BundleGenerator) addProf() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCPUProfile() error {
|
||||
if len(g.cpuProfile) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(g.cpuProfile)
|
||||
if err := g.addFileToZip(reader, "cpu.prof"); err != nil {
|
||||
return fmt.Errorf("add CPU profile to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStackTrace() error {
|
||||
buf := make([]byte, 5242880) // 5 MB buffer
|
||||
n := runtime.Stack(buf, true)
|
||||
@@ -710,14 +745,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addLogfile() error {
|
||||
if g.logFile == "" {
|
||||
if g.logPath == "" {
|
||||
log.Debugf("skipping empty log file in debug bundle")
|
||||
return nil
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(g.logFile)
|
||||
logDir := filepath.Dir(g.logPath)
|
||||
|
||||
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
|
||||
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -507,15 +507,13 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
|
||||
if p.Base == expr.PayloadBaseNetworkHeader {
|
||||
switch p.Offset {
|
||||
case 12:
|
||||
if p.Len == 4 {
|
||||
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
} else if p.Len == 2 {
|
||||
switch p.Len {
|
||||
case 4, 2:
|
||||
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
}
|
||||
case 16:
|
||||
if p.Len == 4 {
|
||||
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
} else if p.Len == 2 {
|
||||
switch p.Len {
|
||||
case 4, 2:
|
||||
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
}
|
||||
}
|
||||
|
||||
101
client/internal/debug/upload.go
Normal file
101
client/internal/debug/upload.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const maxBundleUploadSize = 50 * 1024 * 1024
|
||||
|
||||
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
||||
response, err := getUploadURL(ctx, url, managementURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = upload(ctx, filePath, response)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return response.Key, nil
|
||||
}
|
||||
|
||||
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
||||
fileData, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
|
||||
defer fileData.Close()
|
||||
|
||||
stat, err := fileData.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat file: %w", err)
|
||||
}
|
||||
|
||||
if stat.Size() > maxBundleUploadSize {
|
||||
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create PUT request: %w", err)
|
||||
}
|
||||
|
||||
req.ContentLength = stat.Size()
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
putResp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload failed: %v", err)
|
||||
}
|
||||
defer putResp.Body.Close()
|
||||
|
||||
if putResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(putResp.Body)
|
||||
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
||||
id := getURLHash(managementURL)
|
||||
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create GET request: %w", err)
|
||||
}
|
||||
|
||||
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(getReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get presigned URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
urlBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
var response types.GetURLResponse
|
||||
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func getURLHash(url string) string {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
|
||||
fileContent := []byte("test file content")
|
||||
err := os.WriteFile(file, fileContent, 0640)
|
||||
require.NoError(t, err)
|
||||
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||
require.NoError(t, err)
|
||||
id := getURLHash(testURL)
|
||||
require.Contains(t, key, id+"/")
|
||||
@@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
|
||||
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
|
||||
if peer.PresharedKey {
|
||||
if peer.PresharedKey != [32]byte{} {
|
||||
sb.WriteString(" preshared key: (hidden)\n")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig DeviceAuthProviderConfig
|
||||
}
|
||||
|
||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||
type DeviceAuthProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
||||
// validate our peer's Wireguard PRIVATE key
|
||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
var mgmTLSEnabled bool
|
||||
if mgmURL.Scheme == "https" {
|
||||
mgmTLSEnabled = true
|
||||
}
|
||||
|
||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||
if err != nil {
|
||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
log.Errorf("failed to retrieve device flow: %v", err)
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: DeviceAuthProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
// keep compatibility with older management versions
|
||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
}
|
||||
|
||||
return deviceAuthorizationFlow, nil
|
||||
}
|
||||
|
||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||
if config.Audience == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||
}
|
||||
if config.ClientID == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||
}
|
||||
if config.TokenEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||
}
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.SkipPTRProcess {
|
||||
if zone.NonAuthoritative {
|
||||
continue
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
|
||||
@@ -3,17 +3,21 @@ package dns
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
)
|
||||
|
||||
const (
|
||||
PriorityMgmtCache = 150
|
||||
PriorityLocal = 100
|
||||
PriorityDNSRoute = 75
|
||||
PriorityDNSRoute = 100
|
||||
PriorityLocal = 75
|
||||
PriorityUpstream = 50
|
||||
PriorityDefault = 1
|
||||
PriorityFallback = -100
|
||||
@@ -43,7 +47,23 @@ type HandlerChain struct {
|
||||
type ResponseWriterChain struct {
|
||||
dns.ResponseWriter
|
||||
origPattern string
|
||||
requestID string
|
||||
shouldContinue bool
|
||||
response *dns.Msg
|
||||
meta map[string]string
|
||||
}
|
||||
|
||||
// RequestID returns the request ID for tracing
|
||||
func (w *ResponseWriterChain) RequestID() string {
|
||||
return w.requestID
|
||||
}
|
||||
|
||||
// SetMeta sets a metadata key-value pair for logging
|
||||
func (w *ResponseWriterChain) SetMeta(key, value string) {
|
||||
if w.meta == nil {
|
||||
w.meta = make(map[string]string)
|
||||
}
|
||||
w.meta[key] = value
|
||||
}
|
||||
|
||||
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
w.shouldContinue = true
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
|
||||
pos := c.findHandlerPosition(entry)
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
|
||||
c.logHandlers()
|
||||
}
|
||||
|
||||
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||
@@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
c.logHandlers()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// logHandlers logs the current handler chain state. Caller must hold the lock.
|
||||
func (c *HandlerChain) logHandlers() {
|
||||
if !log.IsLevelEnabled(log.TraceLevel) {
|
||||
return
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
|
||||
for _, h := range c.handlers {
|
||||
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
|
||||
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
|
||||
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
|
||||
" priority=" + strconv.Itoa(h.Priority) + "\n")
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
qname := strings.ToLower(r.Question[0].Name)
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
|
||||
c.mu.RLock()
|
||||
handlers := slices.Clone(c.handlers)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
|
||||
for _, h := range handlers {
|
||||
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
matched := c.isHandlerMatch(qname, entry)
|
||||
|
||||
if matched {
|
||||
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
// Only log continue for non-management cache handlers to reduce noise
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
|
||||
handlerName := entry.OrigPattern
|
||||
if s, ok := entry.Handler.(interface{ String() string }); ok {
|
||||
handlerName = s.String()
|
||||
}
|
||||
|
||||
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
|
||||
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
requestID: requestID,
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
logger.Tracef("handler requested continue for domain=%s", qname)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
c.logResponse(logger, chainWriter, qname, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
// No handler matched or all handlers passed
|
||||
log.Tracef("no handler found for domain=%s", qname)
|
||||
logger.Tracef("no handler found for domain=%s type=%s class=%s",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeRefused)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
|
||||
if cw.response == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var meta string
|
||||
for k, v := range cw.meta {
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
|
||||
@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD exact match",
|
||||
handlerDomain: "example.x.",
|
||||
queryDomain: "example.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD subdomain match",
|
||||
handlerDomain: "example.x.",
|
||||
queryDomain: "sub.example.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single letter TLD wildcard match",
|
||||
handlerDomain: "*.example.x.",
|
||||
queryDomain: "sub.example.x.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "two letter domain labels",
|
||||
handlerDomain: "a.b.",
|
||||
queryDomain: "a.b.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single character domain",
|
||||
handlerDomain: "x.",
|
||||
queryDomain: "x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "single character domain with subdomain match",
|
||||
handlerDomain: "x.",
|
||||
queryDomain: "sub.x.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -38,6 +40,9 @@ const (
|
||||
type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
systemDNSSettings SystemDNSSettings
|
||||
|
||||
mu sync.RWMutex
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
func newHostManager() (*systemConfigurator, error) {
|
||||
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
}
|
||||
|
||||
var dnsSettings SystemDNSSettings
|
||||
var serverAddresses []netip.Addr
|
||||
inSearchDomainsArray := false
|
||||
inServerAddressesArray := false
|
||||
|
||||
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||
} else if inServerAddressesArray {
|
||||
address := strings.Split(line, " : ")[1]
|
||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip.Unmap()
|
||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||
ip = ip.Unmap()
|
||||
serverAddresses = append(serverAddresses, ip)
|
||||
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||
dnsSettings.ServerIP = ip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
// default to 53 port
|
||||
dnsSettings.ServerPort = DefaultPort
|
||||
|
||||
s.mu.Lock()
|
||||
s.origNameservers = serverAddresses
|
||||
s.mu.Unlock()
|
||||
|
||||
return dnsSettings, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
|
||||
@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
|
||||
_, err := cmd.CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func TestGetOriginalNameservers(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
origNameservers: []netip.Addr{
|
||||
netip.MustParseAddr("8.8.8.8"),
|
||||
netip.MustParseAddr("1.1.1.1"),
|
||||
},
|
||||
}
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
assert.Len(t, servers, 2)
|
||||
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
||||
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
||||
}
|
||||
|
||||
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
|
||||
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
||||
|
||||
for _, server := range servers {
|
||||
assert.True(t, server.IsValid(), "server address should be valid")
|
||||
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
||||
}
|
||||
|
||||
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
||||
}
|
||||
|
||||
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
stateFile := filepath.Join(tmpDir, "state.json")
|
||||
sm := statemanager.New(stateFile)
|
||||
sm.RegisterState(&ShutdownState{})
|
||||
sm.Start()
|
||||
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
cleanup := func() {
|
||||
_ = sm.Stop(context.Background())
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}
|
||||
|
||||
return configurator, sm, cleanup
|
||||
}
|
||||
|
||||
func TestOriginalNameserversNoTransition(t *testing.T) {
|
||||
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
routeAll bool
|
||||
}{
|
||||
{"routeall_false", false},
|
||||
{"routeall_true", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
initialServers := configurator.getOriginalNameservers()
|
||||
t.Logf("Initial servers: %v", initialServers)
|
||||
require.NotEmpty(t, initialServers)
|
||||
|
||||
for _, srv := range initialServers {
|
||||
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netbirdIP,
|
||||
ServerPort: 53,
|
||||
RouteAll: tc.routeAll,
|
||||
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||
}
|
||||
|
||||
for i := 1; i <= 2; i++ {
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
|
||||
servers := configurator.getOriginalNameservers()
|
||||
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
||||
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialRoute bool
|
||||
}{
|
||||
{"start_with_routeall_false", false},
|
||||
{"start_with_routeall_true", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := configurator.getSystemDNSSettings()
|
||||
require.NoError(t, err)
|
||||
initialServers := configurator.getOriginalNameservers()
|
||||
t.Logf("Initial servers: %v", initialServers)
|
||||
require.NotEmpty(t, initialServers)
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netbirdIP,
|
||||
ServerPort: 53,
|
||||
RouteAll: tc.initialRoute,
|
||||
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||
}
|
||||
|
||||
// First apply
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers := configurator.getOriginalNameservers()
|
||||
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
// Toggle RouteAll
|
||||
config.RouteAll = !tc.initialRoute
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers = configurator.getOriginalNameservers()
|
||||
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
// Toggle back
|
||||
config.RouteAll = tc.initialRoute
|
||||
err = configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
servers = configurator.getOriginalNameservers()
|
||||
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
||||
assert.Equal(t, initialServers, servers)
|
||||
|
||||
for _, srv := range servers {
|
||||
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,30 +1,52 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const externalResolutionTimeout = 4 * time.Second
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
domains map[domain.Domain]struct{}
|
||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||
zones map[domain.Domain]bool
|
||||
resolver resolver
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewResolver() *Resolver {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
domains: make(map[domain.Domain]struct{}),
|
||||
zones: make(map[domain.Domain]bool),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +59,18 @@ func (d *Resolver) String() string {
|
||||
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
|
||||
}
|
||||
|
||||
func (d *Resolver) Stop() {}
|
||||
func (d *Resolver) Stop() {
|
||||
if d.cancel != nil {
|
||||
d.cancel()
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
maps.Clear(d.zones)
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *Resolver) ID() types.HandlerID {
|
||||
@@ -48,60 +81,150 @@ func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
log.Debugf("received local resolver request with no question")
|
||||
logger.Debug("received local resolver request with no question")
|
||||
return
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(question)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
// Check if we have any records for this domain name with different types
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
|
||||
}
|
||||
result := d.lookupRecords(logger, question)
|
||||
replyMessage.Authoritative = !result.hasExternalData
|
||||
replyMessage.Answer = result.records
|
||||
replyMessage.Rcode = d.determineRcode(question, result)
|
||||
|
||||
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
|
||||
d.continueToNext(logger, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(replyMessage); err != nil {
|
||||
log.Warnf("failed to write the local resolver response: %v", err)
|
||||
logger.Warnf("failed to write the local resolver response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// determineRcode returns the appropriate DNS response code.
|
||||
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
|
||||
// even if CNAME records are included in the answer.
|
||||
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
|
||||
// Use the rcode from lookup - this properly handles CNAME chains where
|
||||
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
|
||||
if result.rcode != 0 {
|
||||
return result.rcode
|
||||
}
|
||||
|
||||
// No records found, but domain exists with different record types (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
|
||||
// findZone finds the matching zone for a query name using reverse suffix lookup.
|
||||
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
|
||||
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
|
||||
qname = strings.ToLower(dns.Fqdn(qname))
|
||||
for {
|
||||
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
|
||||
return nonAuth, true
|
||||
}
|
||||
// Move to parent domain
|
||||
idx := strings.Index(qname, ".")
|
||||
if idx == -1 || idx == len(qname)-1 {
|
||||
return false, false
|
||||
}
|
||||
qname = qname[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// shouldFallthrough checks if the query should fallthrough to the next handler.
|
||||
// Returns true if the queried name belongs to a non-authoritative zone.
|
||||
func (d *Resolver) shouldFallthrough(qname string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
nonAuth, found := d.findZone(qname)
|
||||
return found && nonAuth
|
||||
}
|
||||
|
||||
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Warnf("failed to write continue signal: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
||||
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
||||
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
_, exists := d.domains[domainName]
|
||||
if !exists && supportsWildcard(qType) {
|
||||
testWild := transformDomainToWildcard(string(domainName))
|
||||
_, exists = d.domains[domain.Domain(testWild)]
|
||||
}
|
||||
return exists
|
||||
}
|
||||
|
||||
// isInManagedZone checks if the given name falls within any of our managed zones.
|
||||
// This is used to avoid unnecessary external resolution for CNAME targets that
|
||||
// are within zones we manage - if we don't have a record for it, it doesn't exist.
|
||||
// Caller must NOT hold the lock.
|
||||
func (d *Resolver) isInManagedZone(name string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
_, found := d.findZone(name)
|
||||
return found
|
||||
}
|
||||
|
||||
// lookupResult contains the result of a DNS lookup operation.
|
||||
type lookupResult struct {
|
||||
records []dns.RR
|
||||
rcode int
|
||||
hasExternalData bool
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
||||
d.mu.RLock()
|
||||
records, found := d.records[question]
|
||||
usingWildcard := false
|
||||
wildQuestion := transformToWildcard(question)
|
||||
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
|
||||
// If the domain exists with any record type, return NODATA instead of wildcard match.
|
||||
if !found && supportsWildcard(question.Qtype) {
|
||||
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
|
||||
records, found = d.records[wildQuestion]
|
||||
usingWildcard = found
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
d.mu.RUnlock()
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
question.Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(question)
|
||||
cnameQuestion := dns.Question{
|
||||
Name: question.Name,
|
||||
Qtype: dns.TypeCNAME,
|
||||
Qclass: question.Qclass,
|
||||
}
|
||||
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
|
||||
}
|
||||
return nil
|
||||
return lookupResult{rcode: dns.RcodeNameError}
|
||||
}
|
||||
|
||||
recordsCopy := slices.Clone(records)
|
||||
@@ -110,29 +233,229 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
// if there's more than one record, rotate them (round-robin)
|
||||
if len(recordsCopy) > 1 {
|
||||
d.mu.Lock()
|
||||
records = d.records[question]
|
||||
q := question
|
||||
if usingWildcard {
|
||||
q = wildQuestion
|
||||
}
|
||||
records = d.records[q]
|
||||
if len(records) > 1 {
|
||||
first := records[0]
|
||||
records = append(records[1:], first)
|
||||
d.records[question] = records
|
||||
d.records[q] = records
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
return recordsCopy
|
||||
if usingWildcard {
|
||||
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
|
||||
}
|
||||
|
||||
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||
func transformToWildcard(question dns.Question) dns.Question {
|
||||
wildQuestion := question
|
||||
wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name)
|
||||
return wildQuestion
|
||||
}
|
||||
|
||||
func transformDomainToWildcard(domain string) string {
|
||||
s := strings.Split(domain, ".")
|
||||
s[0] = "*"
|
||||
return strings.Join(s, ".")
|
||||
}
|
||||
|
||||
func supportsWildcard(queryType uint16) bool {
|
||||
return queryType != dns.TypeNS && queryType != dns.TypeSOA
|
||||
}
|
||||
|
||||
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
|
||||
records := make([]dns.RR, len(wildRecords))
|
||||
for i, record := range wildRecords {
|
||||
copiedRecord := dns.Copy(record)
|
||||
copiedRecord.Header().Name = originalName
|
||||
records[i] = copiedRecord
|
||||
}
|
||||
|
||||
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
||||
// the final resolved record of the requested type. This is required for musl libc
|
||||
// compatibility, which expects the full answer chain rather than just the CNAME.
|
||||
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
|
||||
const maxDepth = 8
|
||||
var chain []dns.RR
|
||||
|
||||
for range maxDepth {
|
||||
cnameRecords := d.getRecords(cnameQuestion)
|
||||
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
|
||||
wildQuestion := transformToWildcard(cnameQuestion)
|
||||
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
|
||||
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
|
||||
}
|
||||
}
|
||||
|
||||
if len(cnameRecords) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
chain = append(chain, cnameRecords...)
|
||||
|
||||
cname, ok := cnameRecords[0].(*dns.CNAME)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
targetName := strings.ToLower(cname.Target)
|
||||
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
|
||||
|
||||
// keep following chain
|
||||
if targetResult.rcode == -1 {
|
||||
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
|
||||
continue
|
||||
}
|
||||
|
||||
return d.buildChainResult(chain, targetResult)
|
||||
}
|
||||
|
||||
if len(chain) > 0 {
|
||||
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// buildChainResult combines CNAME chain records with the target resolution result.
|
||||
// Per RFC 6604, the final rcode is propagated through the chain.
|
||||
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
|
||||
records := chain
|
||||
if len(target.records) > 0 {
|
||||
records = append(records, target.records...)
|
||||
}
|
||||
|
||||
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
|
||||
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
|
||||
return lookupResult{
|
||||
records: records,
|
||||
rcode: dns.RcodeServerFailure,
|
||||
hasExternalData: true,
|
||||
}
|
||||
}
|
||||
|
||||
return lookupResult{
|
||||
records: records,
|
||||
rcode: target.rcode,
|
||||
hasExternalData: target.hasExternalData,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCNAMETarget attempts to resolve a CNAME target name.
|
||||
// Returns rcode=-1 to signal "keep following the chain".
|
||||
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
|
||||
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
|
||||
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// another CNAME, keep following
|
||||
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
|
||||
return lookupResult{rcode: -1}
|
||||
}
|
||||
|
||||
// domain exists locally but not this record type (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// in our zone but doesn't exist (NXDOMAIN)
|
||||
if d.isInManagedZone(targetName) {
|
||||
return lookupResult{rcode: dns.RcodeNameError}
|
||||
}
|
||||
|
||||
return d.resolveExternal(logger, targetName, targetType)
|
||||
}
|
||||
|
||||
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return d.records[q]
|
||||
}
|
||||
|
||||
func (d *Resolver) hasRecord(q dns.Question) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
_, ok := d.records[q]
|
||||
return ok
|
||||
}
|
||||
|
||||
// resolveExternal resolves a domain name using the system resolver.
|
||||
// This is used to resolve CNAME targets that point outside our local zone,
|
||||
// which is required for musl libc compatibility (musl expects complete answers).
|
||||
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return lookupResult{rcode: dns.RcodeNotImplemented}
|
||||
}
|
||||
|
||||
resolver := d.resolver
|
||||
if resolver == nil {
|
||||
resolver = net.DefaultResolver
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
|
||||
if result.Err != nil {
|
||||
d.logDNSError(logger, name, qtype, result.Err)
|
||||
return lookupResult{rcode: result.Rcode, hasExternalData: true}
|
||||
}
|
||||
|
||||
return lookupResult{
|
||||
records: resutil.IPsToRRs(name, result.IPs, 60),
|
||||
rcode: dns.RcodeSuccess,
|
||||
hasExternalData: true,
|
||||
}
|
||||
}
|
||||
|
||||
// logDNSError logs DNS resolution errors for debugging.
|
||||
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
|
||||
qtypeName := dns.TypeToString[qtype]
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.IsNotFound {
|
||||
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
|
||||
} else {
|
||||
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update replaces all zones and their records
|
||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
maps.Clear(d.zones)
|
||||
|
||||
for _, rec := range update {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
continue
|
||||
for _, zone := range customZones {
|
||||
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
|
||||
d.zones[zoneDomain] = zone.NonAuthoritative
|
||||
|
||||
for _, rec := range zone.Records {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user