mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-30 22:26:42 +00:00
Compare commits
76 Commits
set-min-pa
...
v0.36.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7335c82553 | ||
|
|
a32ec97911 | ||
|
|
5c05131a94 | ||
|
|
b6abd4b4da | ||
|
|
2605948e01 | ||
|
|
eb2ac039c7 | ||
|
|
790a9ed7df | ||
|
|
2e61ce006d | ||
|
|
3cc485759e | ||
|
|
aafa9c67fc | ||
|
|
69f48db0a3 | ||
|
|
8c965434ae | ||
|
|
78da6b42ad | ||
|
|
1ad2cb5582 | ||
|
|
c619bf5b0c | ||
|
|
9f4db0a953 | ||
|
|
3e836db1d1 | ||
|
|
c01874e9ce | ||
|
|
1b2517ea20 | ||
|
|
3e9f0d57ac | ||
|
|
481bbe8513 | ||
|
|
bc7b2c6ba3 | ||
|
|
c6f7a299a9 | ||
|
|
992a6c79b4 | ||
|
|
78795a4a73 | ||
|
|
5a82477d48 | ||
|
|
1ffa519387 | ||
|
|
e4a25b6a60 | ||
|
|
6a6b527f24 | ||
|
|
b34887a920 | ||
|
|
b9efda3ce8 | ||
|
|
516de93627 | ||
|
|
15f0a665f8 | ||
|
|
9b5b632ff9 | ||
|
|
0c28099712 | ||
|
|
522dd44bfa | ||
|
|
8154069e77 | ||
|
|
e161a92898 | ||
|
|
3fce8485bb | ||
|
|
1cc88a2190 | ||
|
|
168ea9560e | ||
|
|
f48e33b395 | ||
|
|
f1ed8599fc | ||
|
|
93f3e1b14b | ||
|
|
649bfb236b | ||
|
|
409003b4f9 | ||
|
|
9e6e34b42d | ||
|
|
d9905d1a57 | ||
|
|
2bd68efc08 | ||
|
|
6848e1e128 | ||
|
|
668aead4c8 | ||
|
|
f08605a7f1 | ||
|
|
02a3feddb8 | ||
|
|
d9487a5749 | ||
|
|
cfa6d09c5e | ||
|
|
a01253c3c8 | ||
|
|
bc013e4888 | ||
|
|
782e3f8853 | ||
|
|
03fd656344 | ||
|
|
18b049cd24 | ||
|
|
2bdb4cb44a | ||
|
|
abbdf20f65 | ||
|
|
43ef64cf67 | ||
|
|
18316be09a | ||
|
|
1a623943c8 | ||
|
|
fbce8bb511 | ||
|
|
445b626dc8 | ||
|
|
b3c87cb5d1 | ||
|
|
0dbaddc7be | ||
|
|
ad9f044aad | ||
|
|
05930ee6b1 | ||
|
|
e670068cab | ||
|
|
b48cf1bf65 | ||
|
|
7ee7ada273 | ||
|
|
82b4e58ad0 | ||
|
|
ddc365f7a0 |
@@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.21-bullseye
|
FROM golang:1.23-bullseye
|
||||||
|
|
||||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||||
&& apt-get -y install --no-install-recommends\
|
&& apt-get -y install --no-install-recommends\
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
"features": {
|
"features": {
|
||||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
||||||
"ghcr.io/devcontainers/features/go:1": {
|
"ghcr.io/devcontainers/features/go:1": {
|
||||||
"version": "1.21"
|
"version": "1.23"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
||||||
|
|||||||
3
.github/workflows/golang-test-darwin.yml
vendored
3
.github/workflows/golang-test-darwin.yml
vendored
@@ -44,4 +44,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||||
|
|
||||||
|
|||||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
copyback: false
|
copyback: false
|
||||||
release: "14.1"
|
release: "14.1"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y go
|
pkg install -y go pkgconf xorg
|
||||||
|
|
||||||
# -x - to print all executed commands
|
# -x - to print all executed commands
|
||||||
# -e - to faile on first error
|
# -e - to faile on first error
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -timeout 1m -failfast ./base62/...
|
||||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||||
time go test -timeout 1m -failfast ./dns/...
|
time go test -timeout 1m -failfast ./dns/...
|
||||||
time go test -timeout 1m -failfast ./encryption/...
|
time go test -timeout 1m -failfast ./encryption/...
|
||||||
|
|||||||
235
.github/workflows/golang-test-linux.yml
vendored
235
.github/workflows/golang-test-linux.yml
vendored
@@ -13,7 +13,7 @@ concurrency:
|
|||||||
jobs:
|
jobs:
|
||||||
build-cache:
|
build-cache:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
@@ -134,9 +134,189 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||||
|
|
||||||
test_management:
|
test_management:
|
||||||
|
needs: [ build-cache ]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Login to Docker hub
|
||||||
|
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||||
|
uses: docker/login-action@v1
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: download mysql image
|
||||||
|
if: matrix.store == 'mysql'
|
||||||
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||||
|
|
||||||
|
benchmark:
|
||||||
|
needs: [ build-cache ]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Login to Docker hub
|
||||||
|
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||||
|
uses: docker/login-action@v1
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: download mysql image
|
||||||
|
if: matrix.store == 'mysql'
|
||||||
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
||||||
|
|
||||||
|
api_benchmark:
|
||||||
|
needs: [ build-cache ]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Login to Docker hub
|
||||||
|
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||||
|
uses: docker/login-action@v1
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: download mysql image
|
||||||
|
if: matrix.store == 'mysql'
|
||||||
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
|
||||||
|
|
||||||
|
api_integration_test:
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@@ -183,56 +363,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
|
||||||
|
|
||||||
benchmark:
|
|
||||||
needs: [ build-cache ]
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
arch: [ '386','amd64' ]
|
|
||||||
store: [ 'sqlite', 'postgres' ]
|
|
||||||
runs-on: ubuntu-22.04
|
|
||||||
steps:
|
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: "1.23.x"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
|
||||||
run: |
|
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache/restore@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
${{ env.cache }}
|
|
||||||
${{ env.modcache }}
|
|
||||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-gotest-cache-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
|
||||||
run: go mod tidy
|
|
||||||
|
|
||||||
- name: check git status
|
|
||||||
run: git --no-pager diff --exit-code
|
|
||||||
|
|
||||||
- name: Test
|
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m ./...
|
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -65,7 +65,7 @@ jobs:
|
|||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
- name: test output
|
- name: test output
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
run: Get-Content test-out.txt
|
run: Get-Content test-out.txt
|
||||||
|
|||||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
@@ -46,7 +46,7 @@ jobs:
|
|||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v4
|
||||||
with:
|
with:
|
||||||
version: latest
|
version: latest
|
||||||
args: --timeout=12m
|
args: --timeout=12m --out-format colored-line-number
|
||||||
|
|||||||
23
.github/workflows/test-infrastructure-files.yml
vendored
23
.github/workflows/test-infrastructure-files.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
store: [ 'sqlite', 'postgres' ]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||||
@@ -34,6 +34,19 @@ jobs:
|
|||||||
--health-timeout 5s
|
--health-timeout 5s
|
||||||
ports:
|
ports:
|
||||||
- 5432:5432
|
- 5432:5432
|
||||||
|
mysql:
|
||||||
|
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
|
||||||
|
env:
|
||||||
|
MYSQL_USER: netbird
|
||||||
|
MYSQL_PASSWORD: mysql
|
||||||
|
MYSQL_ROOT_PASSWORD: mysqlroot
|
||||||
|
MYSQL_DATABASE: netbird
|
||||||
|
options: >-
|
||||||
|
--health-cmd "mysqladmin ping --silent"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
ports:
|
||||||
|
- 3306:3306
|
||||||
steps:
|
steps:
|
||||||
- name: Set Database Connection String
|
- name: Set Database Connection String
|
||||||
run: |
|
run: |
|
||||||
@@ -42,6 +55,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
|
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
|
||||||
fi
|
fi
|
||||||
|
if [ "${{ matrix.store }}" == "mysql" ]; then
|
||||||
|
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
|
||||||
|
else
|
||||||
|
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Install jq
|
- name: Install jq
|
||||||
run: sudo apt-get install -y jq
|
run: sudo apt-get install -y jq
|
||||||
@@ -84,6 +102,7 @@ jobs:
|
|||||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
|
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
|
||||||
|
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||||
|
|
||||||
- name: check values
|
- name: check values
|
||||||
@@ -112,6 +131,7 @@ jobs:
|
|||||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||||
|
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||||
|
|
||||||
@@ -149,6 +169,7 @@ jobs:
|
|||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
||||||
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
||||||
|
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||||
# check relay values
|
# check relay values
|
||||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||||
|
|||||||
@@ -179,6 +179,51 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
ids:
|
||||||
|
- netbird
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: client/Dockerfile-rootless
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: client/Dockerfile-rootless
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
ids:
|
||||||
|
- netbird
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: client/Dockerfile-rootless
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-amd64
|
- netbirdio/relay:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
@@ -377,6 +422,18 @@ docker_manifests:
|
|||||||
- netbirdio/netbird:{{ .Version }}-arm
|
- netbirdio/netbird:{{ .Version }}-arm
|
||||||
- netbirdio/netbird:{{ .Version }}-amd64
|
- netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird:{{ .Version }}-rootless
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird:rootless-latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
|
||||||
- name_template: netbirdio/relay:{{ .Version }}
|
- name_template: netbirdio/relay:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
|
|||||||
@@ -1,10 +1,3 @@
|
|||||||
<p align="center">
|
|
||||||
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
|
|
||||||
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
|
|
||||||
Learn more
|
|
||||||
</a>
|
|
||||||
</p>
|
|
||||||
<br/>
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img width="234" src="docs/media/logo-full.png"/>
|
<img width="234" src="docs/media/logo-full.png"/>
|
||||||
|
|||||||
16
client/Dockerfile-rootless
Normal file
16
client/Dockerfile-rootless
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
FROM alpine:3.21.0
|
||||||
|
|
||||||
|
COPY netbird /usr/local/bin/netbird
|
||||||
|
|
||||||
|
RUN apk add --no-cache ca-certificates \
|
||||||
|
&& adduser -D -h /var/lib/netbird netbird
|
||||||
|
WORKDIR /var/lib/netbird
|
||||||
|
USER netbird:netbird
|
||||||
|
|
||||||
|
ENV NB_FOREGROUND_MODE=true
|
||||||
|
ENV NB_USE_NETSTACK_MODE=true
|
||||||
|
ENV NB_CONFIG=config.json
|
||||||
|
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||||
|
ENV NB_DISABLE_DNS=true
|
||||||
|
|
||||||
|
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]
|
||||||
@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ type Anonymizer struct {
|
|||||||
currentAnonIPv6 netip.Addr
|
currentAnonIPv6 netip.Addr
|
||||||
startAnonIPv4 netip.Addr
|
startAnonIPv4 netip.Addr
|
||||||
startAnonIPv6 netip.Addr
|
startAnonIPv6 netip.Addr
|
||||||
|
|
||||||
|
domainKeyRegex *regexp.Regexp
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||||
@@ -36,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
|||||||
currentAnonIPv6: startIPv6,
|
currentAnonIPv6: startIPv6,
|
||||||
startAnonIPv4: startIPv4,
|
startAnonIPv4: startIPv4,
|
||||||
startAnonIPv6: startIPv6,
|
startAnonIPv6: startIPv6,
|
||||||
|
|
||||||
|
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,20 +175,15 @@ func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
|||||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
|
|
||||||
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||||
domainPattern := `dns\.Question{Name:"([^"]+)",`
|
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||||
domainRegex := regexp.MustCompile(domainPattern)
|
parts := strings.SplitN(match, "=", 2)
|
||||||
|
|
||||||
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
|
||||||
parts := strings.Split(match, `"`)
|
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
domain := parts[1]
|
domain := parts[1]
|
||||||
if strings.HasSuffix(domain, anonTLD) {
|
if strings.HasSuffix(domain, anonTLD) {
|
||||||
return match
|
return match
|
||||||
}
|
}
|
||||||
randomDomain := generateRandomString(10) + anonTLD
|
return "domain=" + a.AnonymizeDomain(domain)
|
||||||
return strings.Replace(match, domain, randomDomain, 1)
|
|
||||||
}
|
}
|
||||||
return match
|
return match
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) {
|
|||||||
|
|
||||||
func TestAnonymizeDNSLogLine(t *testing.T) {
|
func TestAnonymizeDNSLogLine(t *testing.T) {
|
||||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||||
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
original string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic domain with trailing content",
|
||||||
|
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Domain with trailing dot",
|
||||||
|
input: "domain=example.com. processing request with status=pending",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple domains in log",
|
||||||
|
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
|
||||||
|
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
|
||||||
|
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Already anonymized domain",
|
||||||
|
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
|
||||||
|
original: "", // nothing should be anonymized
|
||||||
|
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Subdomain with trailing dot",
|
||||||
|
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Handler chain pattern log",
|
||||||
|
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
result := anonymizer.AnonymizeDNSLogLine(testLog)
|
for _, tc := range tests {
|
||||||
require.NotEqual(t, testLog, result)
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.NotContains(t, result, "example.com")
|
result := anonymizer.AnonymizeDNSLogLine(tc.input)
|
||||||
|
if tc.original != "" {
|
||||||
|
assert.NotContains(t, result, tc.original)
|
||||||
|
}
|
||||||
|
assert.Regexp(t, tc.expect, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymizeDomain(t *testing.T) {
|
func TestAnonymizeDomain(t *testing.T) {
|
||||||
|
|||||||
173
client/cmd/networks.go
Normal file
173
client/cmd/networks.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var appendFlag bool
|
||||||
|
|
||||||
|
var networksCMD = &cobra.Command{
|
||||||
|
Use: "networks",
|
||||||
|
Aliases: []string{"routes"},
|
||||||
|
Short: "Manage networks",
|
||||||
|
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var routesListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List networks",
|
||||||
|
Example: " netbird networks list",
|
||||||
|
Long: "List all available network routes.",
|
||||||
|
RunE: networksList,
|
||||||
|
}
|
||||||
|
|
||||||
|
var routesSelectCmd = &cobra.Command{
|
||||||
|
Use: "select network...|all",
|
||||||
|
Short: "Select network",
|
||||||
|
Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.",
|
||||||
|
Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3",
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: networksSelect,
|
||||||
|
}
|
||||||
|
|
||||||
|
var routesDeselectCmd = &cobra.Command{
|
||||||
|
Use: "deselect network...|all",
|
||||||
|
Short: "Deselect networks",
|
||||||
|
Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.",
|
||||||
|
Example: " netbird networks deselect all\n netbird networks deselect route1 route2",
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: networksDeselect,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
|
||||||
|
}
|
||||||
|
|
||||||
|
func networksList(cmd *cobra.Command, _ []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Routes) == 0 {
|
||||||
|
cmd.Println("No networks available.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
printNetworks(cmd, resp)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
|
||||||
|
cmd.Println("Available Networks:")
|
||||||
|
for _, route := range resp.Routes {
|
||||||
|
printNetwork(cmd, route)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func printNetwork(cmd *cobra.Command, route *proto.Network) {
|
||||||
|
selectedStatus := getSelectedStatus(route)
|
||||||
|
domains := route.GetDomains()
|
||||||
|
|
||||||
|
if len(domains) > 0 {
|
||||||
|
printDomainRoute(cmd, route, domains, selectedStatus)
|
||||||
|
} else {
|
||||||
|
printNetworkRoute(cmd, route, selectedStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSelectedStatus(route *proto.Network) string {
|
||||||
|
if route.GetSelected() {
|
||||||
|
return "Selected"
|
||||||
|
}
|
||||||
|
return "Not Selected"
|
||||||
|
}
|
||||||
|
|
||||||
|
func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) {
|
||||||
|
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
||||||
|
resolvedIPs := route.GetResolvedIPs()
|
||||||
|
|
||||||
|
if len(resolvedIPs) > 0 {
|
||||||
|
printResolvedIPs(cmd, domains, resolvedIPs)
|
||||||
|
} else {
|
||||||
|
cmd.Printf(" Resolved IPs: -\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) {
|
||||||
|
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
|
||||||
|
cmd.Printf(" Resolved IPs:\n")
|
||||||
|
for resolvedDomain, ipList := range resolvedIPs {
|
||||||
|
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func networksSelect(cmd *cobra.Command, args []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
req := &proto.SelectNetworksRequest{
|
||||||
|
NetworkIDs: args,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) == 1 && args[0] == "all" {
|
||||||
|
req.All = true
|
||||||
|
} else if appendFlag {
|
||||||
|
req.Append = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.SelectNetworks(cmd.Context(), req); err != nil {
|
||||||
|
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("Networks selected successfully.")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func networksDeselect(cmd *cobra.Command, args []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
req := &proto.SelectNetworksRequest{
|
||||||
|
NetworkIDs: args,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) == 1 && args[0] == "all" {
|
||||||
|
req.All = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil {
|
||||||
|
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("Networks deselected successfully.")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -38,6 +38,7 @@ const (
|
|||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
systemInfoFlag = "system-info"
|
systemInfoFlag = "system-info"
|
||||||
|
blockLANAccessFlag = "block-lan-access"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -73,6 +74,7 @@ var (
|
|||||||
anonymizeFlag bool
|
anonymizeFlag bool
|
||||||
debugSystemInfoFlag bool
|
debugSystemInfoFlag bool
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
|
blockLANAccess bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
@@ -142,14 +144,14 @@ func init() {
|
|||||||
rootCmd.AddCommand(loginCmd)
|
rootCmd.AddCommand(loginCmd)
|
||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
rootCmd.AddCommand(sshCmd)
|
rootCmd.AddCommand(sshCmd)
|
||||||
rootCmd.AddCommand(routesCmd)
|
rootCmd.AddCommand(networksCMD)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
|
||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||||
|
|
||||||
routesCmd.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
|
|
||||||
debugCmd.AddCommand(debugBundleCmd)
|
debugCmd.AddCommand(debugBundleCmd)
|
||||||
debugCmd.AddCommand(logCmd)
|
debugCmd.AddCommand(logCmd)
|
||||||
|
|||||||
@@ -1,174 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
var appendFlag bool
|
|
||||||
|
|
||||||
var routesCmd = &cobra.Command{
|
|
||||||
Use: "routes",
|
|
||||||
Short: "Manage network routes",
|
|
||||||
Long: `Commands to list, select, or deselect network routes.`,
|
|
||||||
}
|
|
||||||
|
|
||||||
var routesListCmd = &cobra.Command{
|
|
||||||
Use: "list",
|
|
||||||
Aliases: []string{"ls"},
|
|
||||||
Short: "List routes",
|
|
||||||
Example: " netbird routes list",
|
|
||||||
Long: "List all available network routes.",
|
|
||||||
RunE: routesList,
|
|
||||||
}
|
|
||||||
|
|
||||||
var routesSelectCmd = &cobra.Command{
|
|
||||||
Use: "select route...|all",
|
|
||||||
Short: "Select routes",
|
|
||||||
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
|
|
||||||
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
|
|
||||||
Args: cobra.MinimumNArgs(1),
|
|
||||||
RunE: routesSelect,
|
|
||||||
}
|
|
||||||
|
|
||||||
var routesDeselectCmd = &cobra.Command{
|
|
||||||
Use: "deselect route...|all",
|
|
||||||
Short: "Deselect routes",
|
|
||||||
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
|
|
||||||
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
|
|
||||||
Args: cobra.MinimumNArgs(1),
|
|
||||||
RunE: routesDeselect,
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
|
|
||||||
}
|
|
||||||
|
|
||||||
func routesList(cmd *cobra.Command, _ []string) error {
|
|
||||||
conn, err := getClient(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Routes) == 0 {
|
|
||||||
cmd.Println("No routes available.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
printRoutes(cmd, resp)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
|
|
||||||
cmd.Println("Available Routes:")
|
|
||||||
for _, route := range resp.Routes {
|
|
||||||
printRoute(cmd, route)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func printRoute(cmd *cobra.Command, route *proto.Route) {
|
|
||||||
selectedStatus := getSelectedStatus(route)
|
|
||||||
domains := route.GetDomains()
|
|
||||||
|
|
||||||
if len(domains) > 0 {
|
|
||||||
printDomainRoute(cmd, route, domains, selectedStatus)
|
|
||||||
} else {
|
|
||||||
printNetworkRoute(cmd, route, selectedStatus)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSelectedStatus(route *proto.Route) string {
|
|
||||||
if route.GetSelected() {
|
|
||||||
return "Selected"
|
|
||||||
}
|
|
||||||
return "Not Selected"
|
|
||||||
}
|
|
||||||
|
|
||||||
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
|
|
||||||
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
|
||||||
resolvedIPs := route.GetResolvedIPs()
|
|
||||||
|
|
||||||
if len(resolvedIPs) > 0 {
|
|
||||||
printResolvedIPs(cmd, domains, resolvedIPs)
|
|
||||||
} else {
|
|
||||||
cmd.Printf(" Resolved IPs: -\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
|
|
||||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
|
||||||
cmd.Printf(" Resolved IPs:\n")
|
|
||||||
for _, domain := range domains {
|
|
||||||
if ipList, exists := resolvedIPs[domain]; exists {
|
|
||||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
|
||||||
conn, err := getClient(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
req := &proto.SelectRoutesRequest{
|
|
||||||
RouteIDs: args,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) == 1 && args[0] == "all" {
|
|
||||||
req.All = true
|
|
||||||
} else if appendFlag {
|
|
||||||
req.Append = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
|
|
||||||
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Println("Routes selected successfully.")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func routesDeselect(cmd *cobra.Command, args []string) error {
|
|
||||||
conn, err := getClient(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
req := &proto.SelectRoutesRequest{
|
|
||||||
RouteIDs: args,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) == 1 && args[0] == "all" {
|
|
||||||
req.All = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
|
|
||||||
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Println("Routes deselected successfully.")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@@ -73,7 +72,7 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Debug(err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ type peerStateDetailOutput struct {
|
|||||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
Latency time.Duration `json:"latency" yaml:"latency"`
|
||||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||||
Routes []string `json:"routes" yaml:"routes"`
|
Routes []string `json:"routes" yaml:"routes"`
|
||||||
|
Networks []string `json:"networks" yaml:"networks"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type peersStateOutput struct {
|
type peersStateOutput struct {
|
||||||
@@ -98,6 +99,7 @@ type statusOutputOverview struct {
|
|||||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||||
Routes []string `json:"routes" yaml:"routes"`
|
Routes []string `json:"routes" yaml:"routes"`
|
||||||
|
Networks []string `json:"networks" yaml:"networks"`
|
||||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,7 +284,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
|
|||||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||||
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
|
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||||
|
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -390,7 +393,8 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
|||||||
TransferSent: transferSent,
|
TransferSent: transferSent,
|
||||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
Latency: pbPeerState.GetLatency().AsDuration(),
|
||||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||||
Routes: pbPeerState.GetRoutes(),
|
Routes: pbPeerState.GetNetworks(),
|
||||||
|
Networks: pbPeerState.GetNetworks(),
|
||||||
}
|
}
|
||||||
|
|
||||||
peersStateDetail = append(peersStateDetail, peerState)
|
peersStateDetail = append(peersStateDetail, peerState)
|
||||||
@@ -491,10 +495,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
|||||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := "-"
|
networks := "-"
|
||||||
if len(overview.Routes) > 0 {
|
if len(overview.Networks) > 0 {
|
||||||
sort.Strings(overview.Routes)
|
sort.Strings(overview.Networks)
|
||||||
routes = strings.Join(overview.Routes, ", ")
|
networks = strings.Join(overview.Networks, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
var dnsServersString string
|
var dnsServersString string
|
||||||
@@ -556,6 +560,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
|||||||
"Interface type: %s\n"+
|
"Interface type: %s\n"+
|
||||||
"Quantum resistance: %s\n"+
|
"Quantum resistance: %s\n"+
|
||||||
"Routes: %s\n"+
|
"Routes: %s\n"+
|
||||||
|
"Networks: %s\n"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||||
overview.DaemonVersion,
|
overview.DaemonVersion,
|
||||||
@@ -568,7 +573,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
|||||||
interfaceIP,
|
interfaceIP,
|
||||||
interfaceTypeString,
|
interfaceTypeString,
|
||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
routes,
|
networks,
|
||||||
|
networks,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
)
|
)
|
||||||
return summary
|
return summary
|
||||||
@@ -631,10 +637,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := "-"
|
networks := "-"
|
||||||
if len(peerState.Routes) > 0 {
|
if len(peerState.Networks) > 0 {
|
||||||
sort.Strings(peerState.Routes)
|
sort.Strings(peerState.Networks)
|
||||||
routes = strings.Join(peerState.Routes, ", ")
|
networks = strings.Join(peerState.Networks, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
peerString := fmt.Sprintf(
|
peerString := fmt.Sprintf(
|
||||||
@@ -652,6 +658,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
" Transfer status (received/sent) %s/%s\n"+
|
" Transfer status (received/sent) %s/%s\n"+
|
||||||
" Quantum resistance: %s\n"+
|
" Quantum resistance: %s\n"+
|
||||||
" Routes: %s\n"+
|
" Routes: %s\n"+
|
||||||
|
" Networks: %s\n"+
|
||||||
" Latency: %s\n",
|
" Latency: %s\n",
|
||||||
peerState.FQDN,
|
peerState.FQDN,
|
||||||
peerState.IP,
|
peerState.IP,
|
||||||
@@ -668,7 +675,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
toIEC(peerState.TransferReceived),
|
toIEC(peerState.TransferReceived),
|
||||||
toIEC(peerState.TransferSent),
|
toIEC(peerState.TransferSent),
|
||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
routes,
|
networks,
|
||||||
|
networks,
|
||||||
peerState.Latency.String(),
|
peerState.Latency.String(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -810,6 +818,14 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
|||||||
|
|
||||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||||
|
|
||||||
|
for i, route := range peer.Networks {
|
||||||
|
peer.Networks[i] = a.AnonymizeIPString(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, route := range peer.Networks {
|
||||||
|
peer.Networks[i] = a.AnonymizeRoute(route)
|
||||||
|
}
|
||||||
|
|
||||||
for i, route := range peer.Routes {
|
for i, route := range peer.Routes {
|
||||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||||
}
|
}
|
||||||
@@ -850,6 +866,10 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i, route := range overview.Networks {
|
||||||
|
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||||
|
}
|
||||||
|
|
||||||
for i, route := range overview.Routes {
|
for i, route := range overview.Routes {
|
||||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
overview.Routes[i] = a.AnonymizeRoute(route)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ var resp = &proto.StatusResponse{
|
|||||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||||
BytesRx: 200,
|
BytesRx: 200,
|
||||||
BytesTx: 100,
|
BytesTx: 100,
|
||||||
Routes: []string{
|
Networks: []string{
|
||||||
"10.1.0.0/24",
|
"10.1.0.0/24",
|
||||||
},
|
},
|
||||||
Latency: durationpb.New(time.Duration(10000000)),
|
Latency: durationpb.New(time.Duration(10000000)),
|
||||||
@@ -93,7 +93,7 @@ var resp = &proto.StatusResponse{
|
|||||||
PubKey: "Some-Pub-Key",
|
PubKey: "Some-Pub-Key",
|
||||||
KernelInterface: true,
|
KernelInterface: true,
|
||||||
Fqdn: "some-localhost.awesome-domain.com",
|
Fqdn: "some-localhost.awesome-domain.com",
|
||||||
Routes: []string{
|
Networks: []string{
|
||||||
"10.10.0.0/24",
|
"10.10.0.0/24",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -149,6 +149,9 @@ var overview = statusOutputOverview{
|
|||||||
Routes: []string{
|
Routes: []string{
|
||||||
"10.1.0.0/24",
|
"10.1.0.0/24",
|
||||||
},
|
},
|
||||||
|
Networks: []string{
|
||||||
|
"10.1.0.0/24",
|
||||||
|
},
|
||||||
Latency: time.Duration(10000000),
|
Latency: time.Duration(10000000),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -230,6 +233,9 @@ var overview = statusOutputOverview{
|
|||||||
Routes: []string{
|
Routes: []string{
|
||||||
"10.10.0.0/24",
|
"10.10.0.0/24",
|
||||||
},
|
},
|
||||||
|
Networks: []string{
|
||||||
|
"10.10.0.0/24",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||||
@@ -295,6 +301,9 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"quantumResistance": false,
|
"quantumResistance": false,
|
||||||
"routes": [
|
"routes": [
|
||||||
"10.1.0.0/24"
|
"10.1.0.0/24"
|
||||||
|
],
|
||||||
|
"networks": [
|
||||||
|
"10.1.0.0/24"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -318,7 +327,8 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"transferSent": 1000,
|
"transferSent": 1000,
|
||||||
"latency": 10000000,
|
"latency": 10000000,
|
||||||
"quantumResistance": false,
|
"quantumResistance": false,
|
||||||
"routes": null
|
"routes": null,
|
||||||
|
"networks": null
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -359,6 +369,9 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"routes": [
|
"routes": [
|
||||||
"10.10.0.0/24"
|
"10.10.0.0/24"
|
||||||
],
|
],
|
||||||
|
"networks": [
|
||||||
|
"10.10.0.0/24"
|
||||||
|
],
|
||||||
"dnsServers": [
|
"dnsServers": [
|
||||||
{
|
{
|
||||||
"servers": [
|
"servers": [
|
||||||
@@ -418,6 +431,8 @@ func TestParsingToYAML(t *testing.T) {
|
|||||||
quantumResistance: false
|
quantumResistance: false
|
||||||
routes:
|
routes:
|
||||||
- 10.1.0.0/24
|
- 10.1.0.0/24
|
||||||
|
networks:
|
||||||
|
- 10.1.0.0/24
|
||||||
- fqdn: peer-2.awesome-domain.com
|
- fqdn: peer-2.awesome-domain.com
|
||||||
netbirdIp: 192.168.178.102
|
netbirdIp: 192.168.178.102
|
||||||
publicKey: Pubkey2
|
publicKey: Pubkey2
|
||||||
@@ -437,6 +452,7 @@ func TestParsingToYAML(t *testing.T) {
|
|||||||
latency: 10ms
|
latency: 10ms
|
||||||
quantumResistance: false
|
quantumResistance: false
|
||||||
routes: []
|
routes: []
|
||||||
|
networks: []
|
||||||
cliVersion: development
|
cliVersion: development
|
||||||
daemonVersion: 0.14.1
|
daemonVersion: 0.14.1
|
||||||
management:
|
management:
|
||||||
@@ -465,6 +481,8 @@ quantumResistance: false
|
|||||||
quantumResistancePermissive: false
|
quantumResistancePermissive: false
|
||||||
routes:
|
routes:
|
||||||
- 10.10.0.0/24
|
- 10.10.0.0/24
|
||||||
|
networks:
|
||||||
|
- 10.10.0.0/24
|
||||||
dnsServers:
|
dnsServers:
|
||||||
- servers:
|
- servers:
|
||||||
- 8.8.8.8:53
|
- 8.8.8.8:53
|
||||||
@@ -509,6 +527,7 @@ func TestParsingToDetail(t *testing.T) {
|
|||||||
Transfer status (received/sent) 200 B/100 B
|
Transfer status (received/sent) 200 B/100 B
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Routes: 10.1.0.0/24
|
Routes: 10.1.0.0/24
|
||||||
|
Networks: 10.1.0.0/24
|
||||||
Latency: 10ms
|
Latency: 10ms
|
||||||
|
|
||||||
peer-2.awesome-domain.com:
|
peer-2.awesome-domain.com:
|
||||||
@@ -525,6 +544,7 @@ func TestParsingToDetail(t *testing.T) {
|
|||||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Routes: -
|
Routes: -
|
||||||
|
Networks: -
|
||||||
Latency: 10ms
|
Latency: 10ms
|
||||||
|
|
||||||
OS: %s/%s
|
OS: %s/%s
|
||||||
@@ -543,6 +563,7 @@ NetBird IP: 192.168.178.100/16
|
|||||||
Interface type: Kernel
|
Interface type: Kernel
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Routes: 10.10.0.0/24
|
Routes: 10.10.0.0/24
|
||||||
|
Networks: 10.10.0.0/24
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||||
|
|
||||||
@@ -564,6 +585,7 @@ NetBird IP: 192.168.178.100/16
|
|||||||
Interface type: Kernel
|
Interface type: Kernel
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Routes: 10.10.0.0/24
|
Routes: 10.10.0.0/24
|
||||||
|
Networks: 10.10.0.0/24
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|||||||
31
client/cmd/system.go
Normal file
31
client/cmd/system.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
// Flag constants for system configuration
|
||||||
|
const (
|
||||||
|
disableClientRoutesFlag = "disable-client-routes"
|
||||||
|
disableServerRoutesFlag = "disable-server-routes"
|
||||||
|
disableDNSFlag = "disable-dns"
|
||||||
|
disableFirewallFlag = "disable-firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
disableClientRoutes bool
|
||||||
|
disableServerRoutes bool
|
||||||
|
disableDNS bool
|
||||||
|
disableFirewall bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Add system flags to upCmd
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
|
||||||
|
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
|
||||||
|
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
|
||||||
|
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||||
|
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||||
|
}
|
||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
@@ -71,7 +73,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -93,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ func init() {
|
|||||||
)
|
)
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -147,6 +148,23 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.DNSRouteInterval = &dnsRouteInterval
|
ic.DNSRouteInterval = &dnsRouteInterval
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||||
|
ic.DisableClientRoutes = &disableClientRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||||
|
ic.DisableServerRoutes = &disableServerRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableDNSFlag).Changed {
|
||||||
|
ic.DisableDNS = &disableDNS
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableFirewallFlag).Changed {
|
||||||
|
ic.DisableFirewall = &disableFirewall
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
ic.BlockLANAccess = &blockLANAccess
|
||||||
|
}
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
providedSetupKey, err := getSetupKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -264,6 +282,23 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||||
|
loginRequest.DisableClientRoutes = &disableClientRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||||
|
loginRequest.DisableServerRoutes = &disableServerRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableDNSFlag).Changed {
|
||||||
|
loginRequest.DisableDns = &disableDNS
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableFirewallFlag).Changed {
|
||||||
|
loginRequest.DisableFirewall = &disableFirewall
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
loginRequest.BlockLanAccess = &blockLANAccess
|
||||||
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
|
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
|
|||||||
24
client/configs/configs.go
Normal file
24
client/configs/configs.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package configs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
var StateDir string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
StateDir = os.Getenv("NB_STATE_DIR")
|
||||||
|
if StateDir != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
|
||||||
|
case "darwin", "linux":
|
||||||
|
StateDir = "/var/lib/netbird"
|
||||||
|
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||||
|
StateDir = "/var/db/netbird"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@ package iptables
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"slices"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -19,8 +19,7 @@ const (
|
|||||||
tableName = "filter"
|
tableName = "filter"
|
||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type aclEntries map[string][][]string
|
type aclEntries map[string][][]string
|
||||||
@@ -84,28 +83,22 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var dPortVal, sPortVal string
|
chain := chainNameInputRules
|
||||||
if dPort != nil && dPort.Values != nil {
|
|
||||||
// TODO: we support only one port per rule in current implementation of ACLs
|
|
||||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
|
||||||
}
|
|
||||||
if sPort != nil && sPort.Values != nil {
|
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var chain string
|
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
|
||||||
if direction == firewall.RuleDirectionOUT {
|
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||||
chain = chainNameOutputRules
|
|
||||||
} else {
|
|
||||||
chain = chainNameInputRules
|
|
||||||
}
|
|
||||||
|
|
||||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
mangleSpecs := slices.Clone(specs)
|
||||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
mangleSpecs = append(mangleSpecs,
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||||
|
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
)
|
||||||
|
|
||||||
|
specs = append(specs, "-j", actionToStr(action))
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
@@ -137,7 +130,7 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -145,16 +138,22 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("rule already exists")
|
return nil, fmt.Errorf("rule already exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("failed to add mangle rule: %v", err)
|
||||||
|
mangleSpecs = nil
|
||||||
|
}
|
||||||
|
|
||||||
rule := &Rule{
|
rule := &Rule{
|
||||||
ruleID: uuid.New().String(),
|
ruleID: uuid.New().String(),
|
||||||
specs: specs,
|
specs: specs,
|
||||||
ipsetName: ipsetName,
|
mangleSpecs: mangleSpecs,
|
||||||
ip: ip.String(),
|
ipsetName: ipsetName,
|
||||||
chain: chain,
|
ip: ip.String(),
|
||||||
|
chain: chain,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
@@ -197,6 +196,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.mangleSpecs != nil {
|
||||||
|
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -214,28 +219,7 @@ func (m *aclManager) Reset() error {
|
|||||||
|
|
||||||
// todo write less destructive cleanup mechanism
|
// todo write less destructive cleanup mechanism
|
||||||
func (m *aclManager) cleanChains() error {
|
func (m *aclManager) cleanChains() error {
|
||||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to list chains: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
rules := m.entries["OUTPUT"]
|
|
||||||
for _, rule := range rules {
|
|
||||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to list chains: %s", err)
|
log.Debugf("failed to list chains: %s", err)
|
||||||
return err
|
return err
|
||||||
@@ -295,12 +279,6 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// chain netbird-acl-output-rules
|
|
||||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
|
||||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for chainName, rules := range m.entries {
|
for chainName, rules := range m.entries {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||||
@@ -329,21 +307,13 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
|
|
||||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
|
|
||||||
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
|
||||||
func (m *aclManager) seedInitialEntries() {
|
func (m *aclManager) seedInitialEntries() {
|
||||||
|
|
||||||
established := getConntrackEstablished()
|
established := getConntrackEstablished()
|
||||||
|
|
||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
|
||||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
|
|
||||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
|
|
||||||
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
||||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||||
@@ -352,17 +322,10 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
m.optionalEntries["FORWARD"] = []entry{
|
m.optionalEntries["FORWARD"] = []entry{
|
||||||
{
|
{
|
||||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||||
position: 2,
|
position: 2,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m.optionalEntries["PREROUTING"] = []entry{
|
|
||||||
{
|
|
||||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
|
||||||
position: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||||
@@ -396,42 +359,26 @@ func (m *aclManager) updateState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func filterRuleSpecs(
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
|
||||||
) (specs []string) {
|
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// don't use IP matching if IP is ip 0.0.0.0
|
||||||
if ip.String() == "0.0.0.0" {
|
if ip.String() == "0.0.0.0" {
|
||||||
matchByIP = false
|
matchByIP = false
|
||||||
}
|
}
|
||||||
switch direction {
|
|
||||||
case firewall.RuleDirectionIN:
|
if matchByIP {
|
||||||
if matchByIP {
|
if ipsetName != "" {
|
||||||
if ipsetName != "" {
|
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
} else {
|
||||||
} else {
|
specs = append(specs, "-s", ip.String())
|
||||||
specs = append(specs, "-s", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case firewall.RuleDirectionOUT:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-d", ip.String())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if protocol != "all" {
|
if protocol != "all" {
|
||||||
specs = append(specs, "-p", protocol)
|
specs = append(specs, "-p", protocol)
|
||||||
}
|
}
|
||||||
if sPort != "" {
|
specs = append(specs, applyPort("--sport", sPort)...)
|
||||||
specs = append(specs, "--sport", sPort)
|
specs = append(specs, applyPort("--dport", dPort)...)
|
||||||
}
|
return specs
|
||||||
if dPort != "" {
|
|
||||||
specs = append(specs, "--dport", dPort)
|
|
||||||
}
|
|
||||||
return append(specs, "-j", actionToStr(action))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func actionToStr(action firewall.Action) string {
|
func actionToStr(action firewall.Action) string {
|
||||||
@@ -441,15 +388,15 @@ func actionToStr(action firewall.Action) string {
|
|||||||
return "DROP"
|
return "DROP"
|
||||||
}
|
}
|
||||||
|
|
||||||
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
|
||||||
switch {
|
switch {
|
||||||
case ipsetName == "":
|
case ipsetName == "":
|
||||||
return ""
|
return ""
|
||||||
case sPort != "" && dPort != "":
|
case sPort != nil && dPort != nil:
|
||||||
return ipsetName + "-sport-dport"
|
return ipsetName + "-sport-dport"
|
||||||
case sPort != "":
|
case sPort != nil:
|
||||||
return ipsetName + "-sport"
|
return ipsetName + "-sport"
|
||||||
case dPort != "":
|
case dPort != nil:
|
||||||
return ipsetName + "-dport"
|
return ipsetName + "-dport"
|
||||||
default:
|
default:
|
||||||
return ipsetName
|
return ipsetName
|
||||||
|
|||||||
@@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
_ string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
@@ -197,29 +196,18 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
net.ParseIP("0.0.0.0"),
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.RuleDirectionIN,
|
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
_, err = m.AddPeerFiltering(
|
return nil
|
||||||
net.ParseIP("0.0.0.0"),
|
|
||||||
"all",
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
firewall.RuleDirectionOUT,
|
|
||||||
firewall.ActionAccept,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
|||||||
@@ -68,27 +68,14 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule1 []fw.Rule
|
|
||||||
t.Run("add first rule", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.2")
|
|
||||||
port := &fw.Port{Values: []int{8080}}
|
|
||||||
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
for _, r := range rule1 {
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []int{8043: 8046},
|
IsRange: true,
|
||||||
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(
|
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -97,15 +84,6 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
|
||||||
for _, r := range rule1 {
|
|
||||||
err := manager.DeletePeerRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err := manager.DeletePeerRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
@@ -118,8 +96,8 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset(nil)
|
err = manager.Reset(nil)
|
||||||
@@ -135,9 +113,6 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManagerIPSet(t *testing.T) {
|
func TestIptablesManagerIPSet(t *testing.T) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
mock := &iFaceMock{
|
mock := &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
@@ -167,33 +142,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule1 []fw.Rule
|
|
||||||
t.Run("add first rule with set", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.2")
|
|
||||||
port := &fw.Port{Values: []int{8080}}
|
|
||||||
rule1, err = manager.AddPeerFiltering(
|
|
||||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
|
||||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
for _, r := range rule1 {
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
|
||||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []int{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(
|
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
|
||||||
"default", "accept HTTPS traffic from ports range",
|
|
||||||
)
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -201,15 +156,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
|
||||||
for _, r := range rule1 {
|
|
||||||
err := manager.DeletePeerRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err := manager.DeletePeerRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
@@ -269,12 +215,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -590,10 +590,10 @@ func applyPort(flag string, port *firewall.Port) []string {
|
|||||||
if len(port.Values) > 1 {
|
if len(port.Values) > 1 {
|
||||||
portList := make([]string, len(port.Values))
|
portList := make([]string, len(port.Values))
|
||||||
for i, p := range port.Values {
|
for i, p := range port.Values {
|
||||||
portList[i] = strconv.Itoa(p)
|
portList[i] = strconv.Itoa(int(p))
|
||||||
}
|
}
|
||||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{flag, strconv.Itoa(port.Values[0])}
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -239,7 +239,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{80}},
|
dPort: &firewall.Port{Values: []uint16{80}},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -252,7 +252,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
@@ -285,7 +285,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
@@ -297,7 +297,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -307,8 +307,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||||
dPort: &firewall.Port{Values: []int{22}},
|
dPort: &firewall.Port{Values: []uint16{22}},
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ type Rule struct {
|
|||||||
ruleID string
|
ruleID string
|
||||||
ipsetName string
|
ipsetName string
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
ip string
|
mangleSpecs []string
|
||||||
chain string
|
ip string
|
||||||
|
chain string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
|
|||||||
@@ -69,7 +69,6 @@ type Manager interface {
|
|||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
direction RuleDirection,
|
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ type Port struct {
|
|||||||
IsRange bool
|
IsRange bool
|
||||||
|
|
||||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||||
Values []int
|
Values []uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// String interface implementation
|
// String interface implementation
|
||||||
@@ -40,7 +40,11 @@ func (p *Port) String() string {
|
|||||||
if ports != "" {
|
if ports != "" {
|
||||||
ports += ","
|
ports += ","
|
||||||
}
|
}
|
||||||
ports += strconv.Itoa(port)
|
ports += strconv.Itoa(int(port))
|
||||||
}
|
}
|
||||||
|
if p.IsRange {
|
||||||
|
ports = "range:" + ports
|
||||||
|
}
|
||||||
|
|
||||||
return ports
|
return ports
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ package nftables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,12 +22,10 @@ import (
|
|||||||
const (
|
const (
|
||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "netbird-acl-input-rules"
|
chainNameInputRules = "netbird-acl-input-rules"
|
||||||
chainNameOutputRules = "netbird-acl-output-rules"
|
|
||||||
|
|
||||||
// filter chains contains the rules that jump to the rules chains
|
// filter chains contains the rules that jump to the rules chains
|
||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
chainNameOutputFilter = "netbird-acl-output-filter"
|
|
||||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||||
chainNamePrerouting = "netbird-rt-prerouting"
|
chainNamePrerouting = "netbird-rt-prerouting"
|
||||||
|
|
||||||
@@ -47,9 +44,9 @@ type AclManager struct {
|
|||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
routingFwChainName string
|
||||||
|
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
chainInputRules *nftables.Chain
|
chainInputRules *nftables.Chain
|
||||||
chainOutputRules *nftables.Chain
|
chainPrerouting *nftables.Chain
|
||||||
|
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
rules map[string]*Rule
|
rules map[string]*Rule
|
||||||
@@ -91,7 +88,6 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
@@ -106,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRules := make([]firewall.Rule, 0, 2)
|
newRules := make([]firewall.Rule, 0, 2)
|
||||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
|
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -123,23 +119,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if r.nftSet == nil {
|
if r.nftSet == nil {
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.GetRuleID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||||
if !ok {
|
if !ok {
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.GetRuleID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := ips[r.ip.String()]; ok {
|
if _, ok := ips[r.ip.String()]; ok {
|
||||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -158,12 +163,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
err = m.rConn.Flush()
|
if r.mangleRule != nil {
|
||||||
if err != nil {
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,38 +225,6 @@ func (m *AclManager) createDefaultAllowRules() error {
|
|||||||
Exprs: expIn,
|
Exprs: expIn,
|
||||||
})
|
})
|
||||||
|
|
||||||
expOut := []expr.Any{
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
// mask
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: []byte{0, 0, 0, 0},
|
|
||||||
Xor: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
// net address
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainOutputRules,
|
|
||||||
Position: 0,
|
|
||||||
Exprs: expOut,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
@@ -262,25 +239,33 @@ func (m *AclManager) Flush() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.chainInputRules); err != nil {
|
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
}
|
}
|
||||||
|
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||||
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
|
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
func (m *AclManager) addIOFiltering(
|
||||||
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
ip net.IP,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipset *nftables.Set,
|
||||||
|
comment string,
|
||||||
|
) (*Rule, error) {
|
||||||
|
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
return &Rule{
|
return &Rule{
|
||||||
r.nftRule,
|
nftRule: r.nftRule,
|
||||||
r.nftSet,
|
mangleRule: r.mangleRule,
|
||||||
r.ruleID,
|
nftSet: r.nftSet,
|
||||||
ip,
|
ruleID: r.ruleID,
|
||||||
|
ip: ip,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,9 +297,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||||
// source address position
|
// source address position
|
||||||
addrOffset := uint32(12)
|
addrOffset := uint32(12)
|
||||||
if direction == firewall.RuleDirectionOUT {
|
|
||||||
addrOffset += 4 // is ipv4 address length
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
@@ -344,65 +326,34 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
expressions = append(expressions, applyPort(sPort, true)...)
|
||||||
expressions = append(expressions,
|
expressions = append(expressions, applyPort(dPort, false)...)
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 0,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*sPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) != 0 {
|
mainExpressions := slices.Clone(expressions)
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*dPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case firewall.ActionAccept:
|
case firewall.ActionAccept:
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||||
case firewall.ActionDrop:
|
case firewall.ActionDrop:
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
||||||
|
|
||||||
var chain *nftables.Chain
|
chain := m.chainInputRules
|
||||||
if direction == firewall.RuleDirectionIN {
|
|
||||||
chain = m.chainInputRules
|
|
||||||
} else {
|
|
||||||
chain = m.chainOutputRules
|
|
||||||
}
|
|
||||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: mainExpressions,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
})
|
})
|
||||||
|
|
||||||
rule := &Rule{
|
rule := &Rule{
|
||||||
nftRule: nftRule,
|
nftRule: nftRule,
|
||||||
nftSet: ipset,
|
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||||
ruleID: ruleId,
|
nftSet: ipset,
|
||||||
ip: ip,
|
ruleID: ruleId,
|
||||||
|
ip: ip,
|
||||||
}
|
}
|
||||||
m.rules[ruleId] = rule
|
m.rules[ruleId] = rule
|
||||||
if ipset != nil {
|
if ipset != nil {
|
||||||
@@ -411,6 +362,59 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||||
|
if m.chainPrerouting == nil {
|
||||||
|
log.Warn("prerouting chain is not created")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
preroutingExprs := slices.Clone(expressions)
|
||||||
|
|
||||||
|
// interface
|
||||||
|
preroutingExprs = append([]expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}, preroutingExprs...)
|
||||||
|
|
||||||
|
// local destination and mark
|
||||||
|
preroutingExprs = append(preroutingExprs,
|
||||||
|
&expr.Fib{
|
||||||
|
Register: 1,
|
||||||
|
ResultADDRTYPE: true,
|
||||||
|
FlagDADDR: true,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||||
|
},
|
||||||
|
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: m.chainPrerouting,
|
||||||
|
Exprs: preroutingExprs,
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (m *AclManager) createDefaultChains() (err error) {
|
func (m *AclManager) createDefaultChains() (err error) {
|
||||||
// chainNameInputRules
|
// chainNameInputRules
|
||||||
chain := m.createChain(chainNameInputRules)
|
chain := m.createChain(chainNameInputRules)
|
||||||
@@ -421,15 +425,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
}
|
}
|
||||||
m.chainInputRules = chain
|
m.chainInputRules = chain
|
||||||
|
|
||||||
// chainNameOutputRules
|
|
||||||
chain = m.createChain(chainNameOutputRules)
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.chainOutputRules = chain
|
|
||||||
|
|
||||||
// netbird-acl-input-filter
|
// netbird-acl-input-filter
|
||||||
// type filter hook input priority filter; policy accept;
|
// type filter hook input priority filter; policy accept;
|
||||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||||
@@ -441,18 +436,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// netbird-acl-output-filter
|
|
||||||
// type filter hook output priority filter; policy accept;
|
|
||||||
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
|
|
||||||
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
|
|
||||||
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
|
|
||||||
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// netbird-acl-forward-filter
|
// netbird-acl-forward-filter
|
||||||
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||||
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||||
@@ -475,7 +458,7 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
// netbird peer IP.
|
// netbird peer IP.
|
||||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
preroutingChain := m.rConn.AddChain(&nftables.Chain{
|
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||||
Name: chainNamePrerouting,
|
Name: chainNamePrerouting,
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
@@ -483,8 +466,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
|||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
})
|
})
|
||||||
|
|
||||||
m.addPreroutingRule(preroutingChain)
|
|
||||||
|
|
||||||
m.addFwmarkToForward(chainFwFilter)
|
m.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
@@ -494,43 +475,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
|
||||||
m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: preroutingChain,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyIIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Fib{
|
|
||||||
Register: 1,
|
|
||||||
ResultADDRTYPE: true,
|
|
||||||
FlagDADDR: true,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
SourceRegister: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||||
m.rConn.InsertRule(&nftables.Rule{
|
m.rConn.InsertRule(&nftables.Rule{
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
@@ -546,8 +490,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
|||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictJump,
|
Kind: expr.VerdictAccept,
|
||||||
Chain: m.chainInputRules.Name,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -619,45 +562,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
|
||||||
dstOp := expr.CmpOpNeq
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: iifname, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: dstOp,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||||
expressions := []expr.Any{
|
expressions := []expr.Any{
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||||
@@ -733,6 +637,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
|||||||
for i := 0; ; i++ {
|
for i := 0; ; i++ {
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Debugf("failed to flush nftables: %v", err)
|
||||||
if !strings.Contains(err.Error(), "busy") {
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -749,7 +654,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||||
if m.workTable == nil || chain == nil {
|
if m.workTable == nil || chain == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -766,22 +671,19 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
|||||||
split := bytes.Split(rule.UserData, []byte(" "))
|
split := bytes.Split(rule.UserData, []byte(" "))
|
||||||
r, ok := m.rules[string(split[0])]
|
r, ok := m.rules[string(split[0])]
|
||||||
if ok {
|
if ok {
|
||||||
*r.nftRule = *rule
|
if mangle {
|
||||||
|
*r.mangleRule = *rule
|
||||||
|
} else {
|
||||||
|
*r.nftRule = *rule
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generatePeerRuleId(
|
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||||
ip net.IP,
|
rulesetID := ":"
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
|
||||||
ipset *nftables.Set,
|
|
||||||
) string {
|
|
||||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
|
||||||
if sPort != nil {
|
if sPort != nil {
|
||||||
rulesetID += sPort.String()
|
rulesetID += sPort.String()
|
||||||
}
|
}
|
||||||
@@ -797,12 +699,6 @@ func generatePeerRuleId(
|
|||||||
return "set:" + ipset.Name + rulesetID
|
return "set:" + ipset.Name + rulesetID
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodePort(port firewall.Port) []byte {
|
|
||||||
bs := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
|
||||||
return bs
|
|
||||||
}
|
|
||||||
|
|
||||||
func ifname(n string) []byte {
|
func ifname(n string) []byte {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
copy(b, n+"\x00")
|
copy(b, n+"\x00")
|
||||||
|
|||||||
@@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
@@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(
|
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
||||||
ip,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []int{53}},
|
|
||||||
fw.RuleDirectionIN,
|
|
||||||
fw.ActionDrop,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -209,12 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := net.ParseIP("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(
|
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
||||||
ip,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []int{80}},
|
|
||||||
fw.RuleDirectionIN,
|
|
||||||
fw.ActionAccept,
|
|
||||||
"",
|
|
||||||
"test rule",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
@@ -313,7 +291,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
netip.MustParsePrefix("10.1.0.0/24"),
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
&fw.Port{Values: []int{443}},
|
&fw.Port{Values: []uint16{443}},
|
||||||
fw.ActionAccept,
|
fw.ActionAccept,
|
||||||
)
|
)
|
||||||
require.NoError(t, err, "failed to add route filtering rule")
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|||||||
@@ -956,12 +956,12 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpGte,
|
Op: expr.CmpOpGte,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||||
},
|
},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpLte,
|
Op: expr.CmpOpLte,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@@ -980,7 +980,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
exprs = append(exprs, &expr.Cmp{
|
exprs = append(exprs, &expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
Data: binaryutil.BigEndian.PutUint16(p),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{80}},
|
dPort: &firewall.Port{Values: []uint16{80}},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
@@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
@@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||||
dPort: &firewall.Port{Values: []int{22}},
|
dPort: &firewall.Port{Values: []uint16{22}},
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ import (
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
nftRule *nftables.Rule
|
nftRule *nftables.Rule
|
||||||
nftSet *nftables.Set
|
mangleRule *nftables.Rule
|
||||||
ruleID string
|
nftSet *nftables.Set
|
||||||
ip net.IP
|
ruleID string
|
||||||
|
ip net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
|
|||||||
@@ -2,7 +2,10 @@
|
|||||||
|
|
||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/statemanager"
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||||
@@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Reset(stateManager)
|
return m.nativeFirewall.Reset(stateManager)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
137
client/firewall/uspfilter/conntrack/common.go
Normal file
137
client/firewall/uspfilter/conntrack/common.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
// common.go
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BaseConnTrack provides common fields and locking for all connection types
|
||||||
|
type BaseConnTrack struct {
|
||||||
|
SourceIP net.IP
|
||||||
|
DestIP net.IP
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
|
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||||
|
established atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// these small methods will be inlined by the compiler
|
||||||
|
|
||||||
|
// UpdateLastSeen safely updates the last seen timestamp
|
||||||
|
func (b *BaseConnTrack) UpdateLastSeen() {
|
||||||
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEstablished safely checks if connection is established
|
||||||
|
func (b *BaseConnTrack) IsEstablished() bool {
|
||||||
|
return b.established.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEstablished safely sets the established state
|
||||||
|
func (b *BaseConnTrack) SetEstablished(state bool) {
|
||||||
|
b.established.Store(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeoutExceeded checks if the connection has exceeded the given timeout
|
||||||
|
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
||||||
|
lastSeen := time.Unix(0, b.lastSeen.Load())
|
||||||
|
return time.Since(lastSeen) > timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPAddr is a fixed-size IP address to avoid allocations
|
||||||
|
type IPAddr [16]byte
|
||||||
|
|
||||||
|
// MakeIPAddr creates an IPAddr from net.IP
|
||||||
|
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
||||||
|
// Optimization: check for v4 first as it's more common
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
copy(addr[12:], ip4)
|
||||||
|
} else {
|
||||||
|
copy(addr[:], ip.To16())
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnKey uniquely identifies a connection
|
||||||
|
type ConnKey struct {
|
||||||
|
SrcIP IPAddr
|
||||||
|
DstIP IPAddr
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeConnKey creates a connection key
|
||||||
|
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
||||||
|
return ConnKey{
|
||||||
|
SrcIP: MakeIPAddr(srcIP),
|
||||||
|
DstIP: MakeIPAddr(dstIP),
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateIPs checks if IPs match without allocation
|
||||||
|
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
||||||
|
if ip4 := pktIP.To4(); ip4 != nil {
|
||||||
|
// Compare IPv4 addresses (last 4 bytes)
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
if connIP[12+i] != ip4[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Compare full IPv6 addresses
|
||||||
|
ip6 := pktIP.To16()
|
||||||
|
for i := 0; i < 16; i++ {
|
||||||
|
if connIP[i] != ip6[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
||||||
|
type PreallocatedIPs struct {
|
||||||
|
sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPreallocatedIPs creates a new IP pool
|
||||||
|
func NewPreallocatedIPs() *PreallocatedIPs {
|
||||||
|
return &PreallocatedIPs{
|
||||||
|
Pool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
ip := make(net.IP, 16)
|
||||||
|
return &ip
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves an IP from the pool
|
||||||
|
func (p *PreallocatedIPs) Get() net.IP {
|
||||||
|
return *p.Pool.Get().(*net.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put returns an IP to the pool
|
||||||
|
func (p *PreallocatedIPs) Put(ip net.IP) {
|
||||||
|
p.Pool.Put(&ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyIP copies an IP address efficiently
|
||||||
|
func copyIP(dst, src net.IP) {
|
||||||
|
if len(src) == 16 {
|
||||||
|
copy(dst, src)
|
||||||
|
} else {
|
||||||
|
// Handle IPv4
|
||||||
|
copy(dst[12:], src.To4())
|
||||||
|
}
|
||||||
|
}
|
||||||
115
client/firewall/uspfilter/conntrack/common_test.go
Normal file
115
client/firewall/uspfilter/conntrack/common_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkIPOperations(b *testing.B) {
|
||||||
|
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = MakeIPAddr(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ValidateIPs", func(b *testing.B) {
|
||||||
|
ip1 := net.ParseIP("192.168.1.1")
|
||||||
|
ip2 := net.ParseIP("192.168.1.1")
|
||||||
|
addr := MakeIPAddr(ip1)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = ValidateIPs(addr, ip2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IPPool", func(b *testing.B) {
|
||||||
|
pool := NewPreallocatedIPs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ip := pool.Get()
|
||||||
|
pool.Put(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
func BenchmarkAtomicOperations(b *testing.B) {
|
||||||
|
conn := &BaseConnTrack{}
|
||||||
|
b.Run("UpdateLastSeen", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsEstablished", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = conn.IsEstablished()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("SetEstablished", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conn.SetEstablished(i%2 == 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("GetLastSeen", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = conn.GetLastSeen()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Memory pressure tests
|
||||||
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Generate different IPs
|
||||||
|
srcIPs := make([]net.IP, 100)
|
||||||
|
dstIPs := make([]net.IP, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||||
|
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
srcIdx := i % len(srcIPs)
|
||||||
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
||||||
|
|
||||||
|
// Simulate some valid inbound packets
|
||||||
|
if i%3 == 0 {
|
||||||
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Generate different IPs
|
||||||
|
srcIPs := make([]net.IP, 100)
|
||||||
|
dstIPs := make([]net.IP, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||||
|
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
srcIdx := i % len(srcIPs)
|
||||||
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
||||||
|
|
||||||
|
// Simulate some valid inbound packets
|
||||||
|
if i%3 == 0 {
|
||||||
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
170
client/firewall/uspfilter/conntrack/icmp.go
Normal file
170
client/firewall/uspfilter/conntrack/icmp.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultICMPTimeout is the default timeout for ICMP connections
|
||||||
|
DefaultICMPTimeout = 30 * time.Second
|
||||||
|
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||||
|
ICMPCleanupInterval = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
|
type ICMPConnKey struct {
|
||||||
|
// Supports both IPv4 and IPv6
|
||||||
|
SrcIP [16]byte
|
||||||
|
DstIP [16]byte
|
||||||
|
Sequence uint16 // ICMP sequence number
|
||||||
|
ID uint16 // ICMP identifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
|
type ICMPConnTrack struct {
|
||||||
|
BaseConnTrack
|
||||||
|
Sequence uint16
|
||||||
|
ID uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICMPTracker manages ICMP connection states
|
||||||
|
type ICMPTracker struct {
|
||||||
|
connections map[ICMPConnKey]*ICMPConnTrack
|
||||||
|
timeout time.Duration
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
mutex sync.RWMutex
|
||||||
|
done chan struct{}
|
||||||
|
ipPool *PreallocatedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
|
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultICMPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker := &ICMPTracker{
|
||||||
|
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||||
|
timeout: timeout,
|
||||||
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.cleanupRoutine()
|
||||||
|
return tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound ICMP Echo Request
|
||||||
|
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||||
|
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
if !exists {
|
||||||
|
srcIPCopy := t.ipPool.Get()
|
||||||
|
dstIPCopy := t.ipPool.Get()
|
||||||
|
copyIP(srcIPCopy, srcIP)
|
||||||
|
copyIP(dstIPCopy, dstIP)
|
||||||
|
|
||||||
|
conn = &ICMPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
SourceIP: srcIPCopy,
|
||||||
|
DestIP: dstIPCopy,
|
||||||
|
},
|
||||||
|
ID: id,
|
||||||
|
Sequence: seq,
|
||||||
|
}
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
conn.established.Store(true)
|
||||||
|
t.connections[key] = conn
|
||||||
|
}
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
|
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||||
|
switch icmpType {
|
||||||
|
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
||||||
|
uint8(layers.ICMPv4TypeTimeExceeded):
|
||||||
|
return true
|
||||||
|
case uint8(layers.ICMPv4TypeEchoReply):
|
||||||
|
// continue processing
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := makeICMPKey(dstIP, srcIP, id, seq)
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.IsEstablished() &&
|
||||||
|
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
|
conn.ID == id &&
|
||||||
|
conn.Sequence == seq
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) cleanupRoutine() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.cleanupTicker.C:
|
||||||
|
t.cleanup()
|
||||||
|
case <-t.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (t *ICMPTracker) cleanup() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, conn := range t.connections {
|
||||||
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
delete(t.connections, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine and releases resources
|
||||||
|
func (t *ICMPTracker) Close() {
|
||||||
|
t.cleanupTicker.Stop()
|
||||||
|
close(t.done)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
for _, conn := range t.connections {
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
}
|
||||||
|
t.connections = nil
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeICMPKey creates an ICMP connection key
|
||||||
|
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
||||||
|
return ICMPConnKey{
|
||||||
|
SrcIP: MakeIPAddr(srcIP),
|
||||||
|
DstIP: MakeIPAddr(dstIP),
|
||||||
|
ID: id,
|
||||||
|
Sequence: seq,
|
||||||
|
}
|
||||||
|
}
|
||||||
39
client/firewall/uspfilter/conntrack/icmp_test.go
Normal file
39
client/firewall/uspfilter/conntrack/icmp_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
352
client/firewall/uspfilter/conntrack/tcp.go
Normal file
352
client/firewall/uspfilter/conntrack/tcp.go
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
// TODO: Send RST packets for invalid/timed-out connections
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MSL (Maximum Segment Lifetime) is typically 2 minutes
|
||||||
|
MSL = 2 * time.Minute
|
||||||
|
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
|
||||||
|
TimeWaitTimeout = 2 * MSL
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TCPSyn uint8 = 0x02
|
||||||
|
TCPAck uint8 = 0x10
|
||||||
|
TCPFin uint8 = 0x01
|
||||||
|
TCPRst uint8 = 0x04
|
||||||
|
TCPPush uint8 = 0x08
|
||||||
|
TCPUrg uint8 = 0x20
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultTCPTimeout is the default timeout for established TCP connections
|
||||||
|
DefaultTCPTimeout = 3 * time.Hour
|
||||||
|
// TCPHandshakeTimeout is timeout for TCP handshake completion
|
||||||
|
TCPHandshakeTimeout = 60 * time.Second
|
||||||
|
// TCPCleanupInterval is how often we check for stale connections
|
||||||
|
TCPCleanupInterval = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPState represents the state of a TCP connection
|
||||||
|
type TCPState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TCPStateNew TCPState = iota
|
||||||
|
TCPStateSynSent
|
||||||
|
TCPStateSynReceived
|
||||||
|
TCPStateEstablished
|
||||||
|
TCPStateFinWait1
|
||||||
|
TCPStateFinWait2
|
||||||
|
TCPStateClosing
|
||||||
|
TCPStateTimeWait
|
||||||
|
TCPStateCloseWait
|
||||||
|
TCPStateLastAck
|
||||||
|
TCPStateClosed
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPConnKey uniquely identifies a TCP connection
|
||||||
|
type TCPConnKey struct {
|
||||||
|
SrcIP [16]byte
|
||||||
|
DstIP [16]byte
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPConnTrack represents a TCP connection state
|
||||||
|
type TCPConnTrack struct {
|
||||||
|
BaseConnTrack
|
||||||
|
State TCPState
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPTracker manages TCP connection states
|
||||||
|
type TCPTracker struct {
|
||||||
|
connections map[ConnKey]*TCPConnTrack
|
||||||
|
mutex sync.RWMutex
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
done chan struct{}
|
||||||
|
timeout time.Duration
|
||||||
|
ipPool *PreallocatedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
|
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||||
|
tracker := &TCPTracker{
|
||||||
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
timeout: timeout,
|
||||||
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.cleanupRoutine()
|
||||||
|
return tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound processes an outbound TCP packet and updates connection state
|
||||||
|
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
|
// Create key before lock
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
if !exists {
|
||||||
|
// Use preallocated IPs
|
||||||
|
srcIPCopy := t.ipPool.Get()
|
||||||
|
dstIPCopy := t.ipPool.Get()
|
||||||
|
copyIP(srcIPCopy, srcIP)
|
||||||
|
copyIP(dstIPCopy, dstIP)
|
||||||
|
|
||||||
|
conn = &TCPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
SourceIP: srcIPCopy,
|
||||||
|
DestIP: dstIPCopy,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
},
|
||||||
|
State: TCPStateNew,
|
||||||
|
}
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
conn.established.Store(false)
|
||||||
|
t.connections[key] = conn
|
||||||
|
}
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
// Lock individual connection for state update
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(conn, flags, true)
|
||||||
|
conn.Unlock()
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
|
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||||
|
if !isValidFlagCombination(flags) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle RST packets
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
conn.Lock()
|
||||||
|
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
conn.Unlock()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
conn.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(conn, flags, false)
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
isEstablished := conn.IsEstablished()
|
||||||
|
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||||
|
conn.Unlock()
|
||||||
|
|
||||||
|
return isEstablished || isValidState
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateState updates the TCP connection state based on flags
|
||||||
|
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||||
|
// Handle RST flag specially - it always causes transition to closed
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch conn.State {
|
||||||
|
case TCPStateNew:
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
|
conn.State = TCPStateSynSent
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateSynSent:
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||||
|
if isOutbound {
|
||||||
|
conn.State = TCPStateSynReceived
|
||||||
|
} else {
|
||||||
|
// Simultaneous open
|
||||||
|
conn.State = TCPStateEstablished
|
||||||
|
conn.SetEstablished(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||||
|
conn.State = TCPStateEstablished
|
||||||
|
conn.SetEstablished(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateEstablished:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
if isOutbound {
|
||||||
|
conn.State = TCPStateFinWait1
|
||||||
|
} else {
|
||||||
|
conn.State = TCPStateCloseWait
|
||||||
|
}
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
switch {
|
||||||
|
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||||
|
// Simultaneous close - both sides sent FIN
|
||||||
|
conn.State = TCPStateClosing
|
||||||
|
case flags&TCPFin != 0:
|
||||||
|
conn.State = TCPStateFinWait2
|
||||||
|
case flags&TCPAck != 0:
|
||||||
|
conn.State = TCPStateFinWait2
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
conn.State = TCPStateTimeWait
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateClosing:
|
||||||
|
if flags&TCPAck != 0 {
|
||||||
|
conn.State = TCPStateTimeWait
|
||||||
|
// Keep established = false from previous state
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
conn.State = TCPStateLastAck
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateLastAck:
|
||||||
|
if flags&TCPAck != 0 {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||||
|
// This is handled by the cleanup routine
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
|
||||||
|
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||||
|
if !isValidFlagCombination(flags) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case TCPStateNew:
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateEstablished:
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateClosing:
|
||||||
|
// In CLOSING state, we should accept the final ACK
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
// In TIME_WAIT, we might see retransmissions
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateClosed:
|
||||||
|
// Accept retransmitted ACKs in closed state
|
||||||
|
// This is important because the final ACK might be lost
|
||||||
|
// and the peer will retransmit their FIN-ACK
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) cleanupRoutine() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.cleanupTicker.C:
|
||||||
|
t.cleanup()
|
||||||
|
case <-t.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) cleanup() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, conn := range t.connections {
|
||||||
|
var timeout time.Duration
|
||||||
|
switch {
|
||||||
|
case conn.State == TCPStateTimeWait:
|
||||||
|
timeout = TimeWaitTimeout
|
||||||
|
case conn.IsEstablished():
|
||||||
|
timeout = t.timeout
|
||||||
|
default:
|
||||||
|
timeout = TCPHandshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
lastSeen := conn.GetLastSeen()
|
||||||
|
if time.Since(lastSeen) > timeout {
|
||||||
|
// Return IPs to pool
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
delete(t.connections, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine and releases resources
|
||||||
|
func (t *TCPTracker) Close() {
|
||||||
|
t.cleanupTicker.Stop()
|
||||||
|
close(t.done)
|
||||||
|
|
||||||
|
// Clean up all remaining IPs
|
||||||
|
t.mutex.Lock()
|
||||||
|
for _, conn := range t.connections {
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
}
|
||||||
|
t.connections = nil
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidFlagCombination(flags uint8) bool {
|
||||||
|
// Invalid: SYN+FIN
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPFin != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid: RST with SYN or FIN
|
||||||
|
if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
308
client/firewall/uspfilter/conntrack/tcp_test.go
Normal file
308
client/firewall/uspfilter/conntrack/tcp_test.go
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
|
dstIP := net.ParseIP("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Security Tests", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
flags uint8
|
||||||
|
wantDrop bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Block unsolicited SYN-ACK",
|
||||||
|
flags: TCPSyn | TCPAck,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block SYN-ACK without prior SYN",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block invalid SYN-FIN",
|
||||||
|
flags: TCPSyn | TCPFin,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block invalid SYN-FIN combination",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block unsolicited RST",
|
||||||
|
flags: TCPRst,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block RST without connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block unsolicited ACK",
|
||||||
|
flags: TCPAck,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block ACK without connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block data without connection",
|
||||||
|
flags: TCPAck | TCPPush,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block data without established connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
||||||
|
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Connection Flow Tests", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
test func(*testing.T)
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Normal Handshake",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Send initial SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||||
|
|
||||||
|
// Receive SYN-ACK
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||||
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
|
// Send ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
|
||||||
|
// Test data transfer
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||||
|
require.True(t, valid, "Data should be allowed after handshake")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Normal Close",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||||
|
require.True(t, valid, "ACK for FIN should be allowed")
|
||||||
|
|
||||||
|
// Receive FIN from other side
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||||
|
require.True(t, valid, "FIN should be allowed")
|
||||||
|
|
||||||
|
// Send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RST During Connection",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Receive RST
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||||
|
require.True(t, valid, "RST should be allowed for established connection")
|
||||||
|
|
||||||
|
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||||
|
// The connection will be cleaned up by timeout
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simultaneous Close",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Both sides send FIN+ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||||
|
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||||
|
|
||||||
|
// Both sides send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||||
|
require.True(t, valid, "Final ACKs should be allowed")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
tt.test(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRSTHandling(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
|
dstIP := net.ParseIP("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupState func()
|
||||||
|
sendRST func()
|
||||||
|
wantValid bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RST in established",
|
||||||
|
setupState: func() {
|
||||||
|
// Establish connection first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
},
|
||||||
|
sendRST: func() {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||||
|
},
|
||||||
|
wantValid: true,
|
||||||
|
desc: "Should accept RST for established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RST without connection",
|
||||||
|
setupState: func() {},
|
||||||
|
sendRST: func() {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||||
|
},
|
||||||
|
wantValid: false,
|
||||||
|
desc: "Should reject RST without connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setupState()
|
||||||
|
tt.sendRST()
|
||||||
|
|
||||||
|
// Verify connection state is as expected
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
if tt.wantValid {
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateClosed, conn.State)
|
||||||
|
require.False(t, conn.IsEstablished())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to establish a TCP connection
|
||||||
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||||
|
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||||
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
if i%2 == 0 {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||||
|
} else {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark connection cleanup
|
||||||
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Pre-populate with expired connections
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for connections to expire
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.cleanup()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
158
client/firewall/uspfilter/conntrack/udp.go
Normal file
158
client/firewall/uspfilter/conntrack/udp.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultUDPTimeout is the default timeout for UDP connections
|
||||||
|
DefaultUDPTimeout = 30 * time.Second
|
||||||
|
// UDPCleanupInterval is how often we check for stale connections
|
||||||
|
UDPCleanupInterval = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// UDPConnTrack represents a UDP connection state
|
||||||
|
type UDPConnTrack struct {
|
||||||
|
BaseConnTrack
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPTracker manages UDP connection states
|
||||||
|
type UDPTracker struct {
|
||||||
|
connections map[ConnKey]*UDPConnTrack
|
||||||
|
timeout time.Duration
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
mutex sync.RWMutex
|
||||||
|
done chan struct{}
|
||||||
|
ipPool *PreallocatedIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
|
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultUDPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker := &UDPTracker{
|
||||||
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
|
timeout: timeout,
|
||||||
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.cleanupRoutine()
|
||||||
|
return tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound UDP connection
|
||||||
|
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
if !exists {
|
||||||
|
srcIPCopy := t.ipPool.Get()
|
||||||
|
dstIPCopy := t.ipPool.Get()
|
||||||
|
copyIP(srcIPCopy, srcIP)
|
||||||
|
copyIP(dstIPCopy, dstIP)
|
||||||
|
|
||||||
|
conn = &UDPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
SourceIP: srcIPCopy,
|
||||||
|
DestIP: dstIPCopy,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
conn.established.Store(true)
|
||||||
|
t.connections[key] = conn
|
||||||
|
}
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
conn.lastSeen.Store(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
|
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||||
|
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.IsEstablished() &&
|
||||||
|
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||||
|
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||||
|
conn.DestPort == srcPort &&
|
||||||
|
conn.SourcePort == dstPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupRoutine periodically removes stale connections
|
||||||
|
func (t *UDPTracker) cleanupRoutine() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.cleanupTicker.C:
|
||||||
|
t.cleanup()
|
||||||
|
case <-t.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) cleanup() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, conn := range t.connections {
|
||||||
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
delete(t.connections, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine and releases resources
|
||||||
|
func (t *UDPTracker) Close() {
|
||||||
|
t.cleanupTicker.Stop()
|
||||||
|
close(t.done)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
for _, conn := range t.connections {
|
||||||
|
t.ipPool.Put(conn.SourceIP)
|
||||||
|
t.ipPool.Put(conn.DestIP)
|
||||||
|
}
|
||||||
|
t.connections = nil
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnection safely retrieves a connection state
|
||||||
|
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
||||||
|
t.mutex.RLock()
|
||||||
|
defer t.mutex.RUnlock()
|
||||||
|
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout returns the configured timeout duration for the tracker
|
||||||
|
func (t *UDPTracker) Timeout() time.Duration {
|
||||||
|
return t.timeout
|
||||||
|
}
|
||||||
243
client/firewall/uspfilter/conntrack/udp_test.go
Normal file
243
client/firewall/uspfilter/conntrack/udp_test.go
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewUDPTracker(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeout time.Duration
|
||||||
|
wantTimeout time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with custom timeout",
|
||||||
|
timeout: 1 * time.Minute,
|
||||||
|
wantTimeout: 1 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with zero timeout uses default",
|
||||||
|
timeout: 0,
|
||||||
|
wantTimeout: DefaultUDPTimeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(tt.timeout)
|
||||||
|
assert.NotNil(t, tracker)
|
||||||
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
|
assert.NotNil(t, tracker.connections)
|
||||||
|
assert.NotNil(t, tracker.cleanupTicker)
|
||||||
|
assert.NotNil(t, tracker.done)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
|
dstIP := net.ParseIP("192.168.1.3")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Verify connection was tracked
|
||||||
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
conn, exists := tracker.connections[key]
|
||||||
|
require.True(t, exists)
|
||||||
|
assert.True(t, conn.SourceIP.Equal(srcIP))
|
||||||
|
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||||
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
|
assert.True(t, conn.IsEstablished())
|
||||||
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(1 * time.Second)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
|
dstIP := net.ParseIP("192.168.1.3")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
// Track outbound connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
srcIP net.IP
|
||||||
|
dstIP net.IP
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
sleep time.Duration
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid inbound response",
|
||||||
|
srcIP: dstIP, // Original destination is now source
|
||||||
|
dstIP: srcIP, // Original source is now destination
|
||||||
|
srcPort: dstPort, // Original destination port is now source
|
||||||
|
dstPort: srcPort, // Original source port is now destination
|
||||||
|
sleep: 0,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid source IP",
|
||||||
|
srcIP: net.ParseIP("192.168.1.4"),
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid destination IP",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: net.ParseIP("192.168.1.4"),
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid source port",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: 54321,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid destination port",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: 54321,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired connection",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 2 * time.Second, // Longer than tracker timeout
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.sleep > 0 {
|
||||||
|
time.Sleep(tt.sleep)
|
||||||
|
}
|
||||||
|
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_Cleanup(t *testing.T) {
|
||||||
|
// Use shorter intervals for testing
|
||||||
|
timeout := 50 * time.Millisecond
|
||||||
|
cleanupInterval := 25 * time.Millisecond
|
||||||
|
|
||||||
|
// Create tracker with custom cleanup interval
|
||||||
|
tracker := &UDPTracker{
|
||||||
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
|
timeout: timeout,
|
||||||
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
ipPool: NewPreallocatedIPs(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start cleanup routine
|
||||||
|
go tracker.cleanupRoutine()
|
||||||
|
|
||||||
|
// Add some connections
|
||||||
|
connections := []struct {
|
||||||
|
srcIP net.IP
|
||||||
|
dstIP net.IP
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
srcIP: net.ParseIP("192.168.1.2"),
|
||||||
|
dstIP: net.ParseIP("192.168.1.3"),
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
srcIP: net.ParseIP("192.168.1.4"),
|
||||||
|
dstIP: net.ParseIP("192.168.1.5"),
|
||||||
|
srcPort: 12346,
|
||||||
|
dstPort: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, conn := range connections {
|
||||||
|
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify initial connections
|
||||||
|
assert.Len(t, tracker.connections, 2)
|
||||||
|
|
||||||
|
// Wait for connection timeout and cleanup interval
|
||||||
|
time.Sleep(timeout + 2*cleanupInterval)
|
||||||
|
|
||||||
|
tracker.mutex.RLock()
|
||||||
|
connCount := len(tracker.connections)
|
||||||
|
tracker.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Verify connections were cleaned up
|
||||||
|
assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up")
|
||||||
|
|
||||||
|
// Properly close the tracker
|
||||||
|
tracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
dstIP := net.ParseIP("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -15,9 +15,8 @@ type Rule struct {
|
|||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
direction firewall.RuleDirection
|
sPort *firewall.Port
|
||||||
sPort uint16
|
dPort *firewall.Port
|
||||||
dPort uint16
|
|
||||||
drop bool
|
drop bool
|
||||||
comment string
|
comment string
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -12,6 +14,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -19,6 +22,8 @@ import (
|
|||||||
|
|
||||||
const layerTypeAll = 0
|
const layerTypeAll = 0
|
||||||
|
|
||||||
|
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||||
)
|
)
|
||||||
@@ -34,7 +39,9 @@ type RuleSet map[string]Rule
|
|||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
outgoingRules map[string]RuleSet
|
// outgoingRules is used for hooks only
|
||||||
|
outgoingRules map[string]RuleSet
|
||||||
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[string]RuleSet
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
@@ -42,6 +49,11 @@ type Manager struct {
|
|||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
stateful bool
|
||||||
|
udpTracker *conntrack.UDPTracker
|
||||||
|
icmpTracker *conntrack.ICMPTracker
|
||||||
|
tcpTracker *conntrack.TCPTracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -73,6 +85,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
|
|||||||
}
|
}
|
||||||
|
|
||||||
func create(iface IFaceMapper) (*Manager, error) {
|
func create(iface IFaceMapper) (*Manager, error) {
|
||||||
|
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
@@ -90,6 +104,16 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[string]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
|
stateful: !disableConntrack,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only initialize trackers if stateful mode is enabled
|
||||||
|
if disableConntrack {
|
||||||
|
log.Info("conntrack is disabled")
|
||||||
|
} else {
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
@@ -134,9 +158,8 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
_ string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
r := Rule{
|
r := Rule{
|
||||||
@@ -144,7 +167,6 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
ip: ip,
|
ip: ip,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
matchByIP: true,
|
matchByIP: true,
|
||||||
direction: direction,
|
|
||||||
drop: action == firewall.ActionDrop,
|
drop: action == firewall.ActionDrop,
|
||||||
comment: comment,
|
comment: comment,
|
||||||
}
|
}
|
||||||
@@ -157,13 +179,8 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
r.matchByIP = false
|
r.matchByIP = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) == 1 {
|
r.sPort = sPort
|
||||||
r.sPort = uint16(sPort.Values[0])
|
r.dPort = dPort
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) == 1 {
|
|
||||||
r.dPort = uint16(dPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
switch proto {
|
switch proto {
|
||||||
case firewall.ProtocolTCP:
|
case firewall.ProtocolTCP:
|
||||||
@@ -180,17 +197,10 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if direction == firewall.RuleDirectionIN {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
|
||||||
}
|
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
|
||||||
} else {
|
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
|
||||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
|
||||||
}
|
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
|
||||||
}
|
}
|
||||||
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
@@ -219,19 +229,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.direction == firewall.RuleDirectionIN {
|
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.incomingRules[r.ip.String()], r.id)
|
|
||||||
} else {
|
|
||||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
|
||||||
}
|
}
|
||||||
|
delete(m.incomingRules[r.ip.String()], r.id)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -249,16 +250,16 @@ func (m *Manager) Flush() error { return nil }
|
|||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
return m.processOutgoingHooks(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.incomingRules, true)
|
return m.dropFilter(packetData, m.incomingRules)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements same logic for booth direction of the traffic
|
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@@ -266,64 +267,235 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
|
|||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
log.Tracef("couldn't decode layer, err: %s", err)
|
return false
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
if len(d.decoded) < 2 {
|
||||||
log.Tracef("not enough levels in network packet")
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
|
if srcIP == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always process UDP hooks
|
||||||
|
if d.decoded[1] == layers.LayerTypeUDP {
|
||||||
|
// Track UDP state only if enabled
|
||||||
|
if m.stateful {
|
||||||
|
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
return m.checkUDPHooks(d, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track other protocols only if stateful mode is enabled
|
||||||
|
if m.stateful {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.trackICMPOutbound(d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
||||||
|
switch d.decoded[0] {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
return d.ip4.SrcIP, d.ip4.DstIP
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
return d.ip6.SrcIP, d.ip6.DstIP
|
||||||
|
default:
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackOutbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
uint16(d.tcp.SrcPort),
|
||||||
|
uint16(d.tcp.DstPort),
|
||||||
|
flags,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||||
|
var flags uint8
|
||||||
|
if tcp.SYN {
|
||||||
|
flags |= conntrack.TCPSyn
|
||||||
|
}
|
||||||
|
if tcp.ACK {
|
||||||
|
flags |= conntrack.TCPAck
|
||||||
|
}
|
||||||
|
if tcp.FIN {
|
||||||
|
flags |= conntrack.TCPFin
|
||||||
|
}
|
||||||
|
if tcp.RST {
|
||||||
|
flags |= conntrack.TCPRst
|
||||||
|
}
|
||||||
|
if tcp.PSH {
|
||||||
|
flags |= conntrack.TCPPush
|
||||||
|
}
|
||||||
|
if tcp.URG {
|
||||||
|
flags |= conntrack.TCPUrg
|
||||||
|
}
|
||||||
|
return flags
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||||
|
m.udpTracker.TrackOutbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
uint16(d.udp.SrcPort),
|
||||||
|
uint16(d.udp.DstPort),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||||
|
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||||
|
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||||
|
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
||||||
|
m.icmpTracker.TrackOutbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
d.icmp4.Id,
|
||||||
|
d.icmp4.Seq,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dropFilter implements filtering logic for incoming packets
|
||||||
|
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||||
|
// TODO: Disable router if --disable-server-router is set
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
d := m.decoders.Get().(*decoder)
|
||||||
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
|
if !m.isValidPacket(d, packetData) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
ipLayer := d.decoded[0]
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
|
if srcIP == nil {
|
||||||
switch ipLayer {
|
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
var ip net.IP
|
if !m.isWireguardTraffic(srcIP, dstIP) {
|
||||||
switch ipLayer {
|
return false
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if isIncomingPacket {
|
|
||||||
ip = d.ip4.SrcIP
|
|
||||||
} else {
|
|
||||||
ip = d.ip4.DstIP
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if isIncomingPacket {
|
|
||||||
ip = d.ip6.SrcIP
|
|
||||||
} else {
|
|
||||||
ip = d.ip6.DstIP
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
// Check connection state only if enabled
|
||||||
if ok {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||||
return filter
|
return false
|
||||||
}
|
|
||||||
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
|
||||||
if ok {
|
|
||||||
return filter
|
|
||||||
}
|
|
||||||
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
|
||||||
if ok {
|
|
||||||
return filter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// default policy is DROP ALL
|
return m.applyRules(srcIP, packetData, rules, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
log.Tracef("couldn't decode layer, err: %s", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(d.decoded) < 2 {
|
||||||
|
log.Tracef("not enough levels in network packet")
|
||||||
|
return false
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
||||||
|
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return m.tcpTracker.IsValidInbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
uint16(d.tcp.SrcPort),
|
||||||
|
uint16(d.tcp.DstPort),
|
||||||
|
getTCPFlags(&d.tcp),
|
||||||
|
)
|
||||||
|
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return m.udpTracker.IsValidInbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
uint16(d.udp.SrcPort),
|
||||||
|
uint16(d.udp.DstPort),
|
||||||
|
)
|
||||||
|
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
return m.icmpTracker.IsValidInbound(
|
||||||
|
srcIP,
|
||||||
|
dstIP,
|
||||||
|
d.icmp4.Id,
|
||||||
|
d.icmp4.Seq,
|
||||||
|
d.icmp4.TypeCode.Type(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: ICMPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||||
|
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default policy: DROP ALL
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||||
|
if rulePort == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if rulePort.IsRange {
|
||||||
|
return packetPort >= rulePort.Values[0] && packetPort <= rulePort.Values[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range rulePort.Values {
|
||||||
|
if p == packetPort {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
@@ -341,13 +513,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
|||||||
|
|
||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
|
||||||
return rule.drop, true
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
@@ -357,13 +523,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
|||||||
return rule.udpHook(packetData), true
|
return rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
|
||||||
return rule.drop, true
|
|
||||||
}
|
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
|
||||||
return rule.drop, true
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
@@ -388,9 +548,8 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
dPort: dPort,
|
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
direction: firewall.RuleDirectionOUT,
|
|
||||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||||
udpHook: hook,
|
udpHook: hook,
|
||||||
}
|
}
|
||||||
@@ -401,7 +560,6 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
r.direction = firewall.RuleDirectionIN
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
}
|
}
|
||||||
@@ -420,19 +578,22 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
// RemovePacketHook removes packet hook by given ID
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
for _, arr := range m.incomingRules {
|
for _, arr := range m.incomingRules {
|
||||||
for _, r := range arr {
|
for _, r := range arr {
|
||||||
if r.id == hookID {
|
if r.id == hookID {
|
||||||
rule := r
|
delete(arr, r.id)
|
||||||
return m.DeletePeerRule(&rule)
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, arr := range m.outgoingRules {
|
for _, arr := range m.outgoingRules {
|
||||||
for _, r := range arr {
|
for _, r := range arr {
|
||||||
if r.id == hookID {
|
if r.id == hookID {
|
||||||
rule := r
|
delete(arr, r.id)
|
||||||
return m.DeletePeerRule(&rule)
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
998
client/firewall/uspfilter/uspfilter_bench_test.go
Normal file
998
client/firewall/uspfilter/uspfilter_bench_test.go
Normal file
@@ -0,0 +1,998 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range
|
||||||
|
func generateRandomIPs(n int) []net.IP {
|
||||||
|
ips := make([]net.IP, n)
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for i := 0; i < n; {
|
||||||
|
ip := make(net.IP, 4)
|
||||||
|
ip[0] = 100
|
||||||
|
ip[1] = byte(64 + rand.Intn(63)) // 64-126
|
||||||
|
ip[2] = byte(rand.Intn(256))
|
||||||
|
ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255
|
||||||
|
|
||||||
|
key := ip.String()
|
||||||
|
if !seen[key] {
|
||||||
|
ips[i] = ip
|
||||||
|
seen[key] = true
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
Protocol: protocol,
|
||||||
|
}
|
||||||
|
|
||||||
|
var transportLayer gopacket.SerializableLayer
|
||||||
|
switch protocol {
|
||||||
|
case layers.IPProtocolTCP:
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
SYN: true,
|
||||||
|
}
|
||||||
|
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
transportLayer = tcp
|
||||||
|
case layers.IPProtocolUDP:
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
|
DstPort: layers.UDPPort(dstPort),
|
||||||
|
}
|
||||||
|
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
transportLayer = udp
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
||||||
|
require.NoError(b, err)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkCoreFiltering focuses on the essential performance comparisons between
|
||||||
|
// stateful and stateless filtering approaches
|
||||||
|
func BenchmarkCoreFiltering(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
stateful bool
|
||||||
|
setupFunc func(*Manager)
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "stateless_single_allow_all",
|
||||||
|
stateful: false,
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
// Single rule allowing all traffic
|
||||||
|
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||||
|
fw.ActionAccept, "", "allow all")
|
||||||
|
require.NoError(b, err)
|
||||||
|
},
|
||||||
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_no_rules",
|
||||||
|
stateful: true,
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
// No explicit rules - rely purely on connection tracking
|
||||||
|
},
|
||||||
|
desc: "Pure connection tracking without any rules",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateless_explicit_return",
|
||||||
|
stateful: false,
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
// Add explicit rules matching return traffic pattern
|
||||||
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
|
ip := generateRandomIPs(1)[0]
|
||||||
|
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||||
|
&fw.Port{Values: []uint16{80}},
|
||||||
|
fw.ActionAccept, "", "explicit return")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
desc: "Explicit rules matching return traffic patterns without state",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_with_established",
|
||||||
|
stateful: true,
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
// Add some basic rules but rely on state for established connections
|
||||||
|
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||||
|
fw.ActionDrop, "", "default drop")
|
||||||
|
require.NoError(b, err)
|
||||||
|
},
|
||||||
|
desc: "Connection tracking with established connections",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test both TCP and UDP
|
||||||
|
protocols := []struct {
|
||||||
|
name string
|
||||||
|
proto layers.IPProtocol
|
||||||
|
}{
|
||||||
|
{"TCP", layers.IPProtocolTCP},
|
||||||
|
{"UDP", layers.IPProtocolUDP},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
for _, proto := range protocols {
|
||||||
|
b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1"))
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create manager and basic setup
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply scenario-specific setup
|
||||||
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
|
// Generate test packets
|
||||||
|
srcIP := generateRandomIPs(1)[0]
|
||||||
|
dstIP := generateRandomIPs(1)[0]
|
||||||
|
srcPort := uint16(1024 + b.N%60000)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto)
|
||||||
|
inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto)
|
||||||
|
|
||||||
|
// For stateful scenarios, establish the connection
|
||||||
|
if sc.stateful {
|
||||||
|
manager.processOutgoingHooks(outbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Measure inbound packet processing
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.dropFilter(inbound, manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkStateScaling measures how performance scales with connection table size
|
||||||
|
func BenchmarkStateScaling(b *testing.B) {
|
||||||
|
connCounts := []int{100, 1000, 10000, 100000}
|
||||||
|
|
||||||
|
for _, count := range connCounts {
|
||||||
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-populate connection table
|
||||||
|
srcIPs := generateRandomIPs(count)
|
||||||
|
dstIPs := generateRandomIPs(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
|
manager.processOutgoingHooks(outbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test packet
|
||||||
|
testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP)
|
||||||
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
|
// First establish our test connection
|
||||||
|
manager.processOutgoingHooks(testOut)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.dropFilter(testIn, manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkEstablishmentOverhead measures the overhead of connection establishment
|
||||||
|
func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
established bool
|
||||||
|
}{
|
||||||
|
{"established", true},
|
||||||
|
{"new", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := generateRandomIPs(1)[0]
|
||||||
|
dstIP := generateRandomIPs(1)[0]
|
||||||
|
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||||
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
|
if sc.established {
|
||||||
|
manager.processOutgoingHooks(outbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.dropFilter(inbound, manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic
|
||||||
|
func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
proto layers.IPProtocol
|
||||||
|
state string // "new", "established", "post_handshake" (TCP only)
|
||||||
|
setupFunc func(*Manager)
|
||||||
|
genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "allow_non_wg_tcp_new",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "new",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
},
|
||||||
|
desc: "Allow non-WG: TCP new connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow_non_wg_tcp_established",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "established",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
// Generate packets with ACK flag for established connection
|
||||||
|
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
|
||||||
|
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
desc: "Allow non-WG: TCP established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow_non_wg_udp_new",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
state: "new",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
desc: "Allow non-WG: UDP new connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow_non_wg_udp_established",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
state: "established",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
desc: "Allow non-WG: UDP established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_tcp_new",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "new",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
},
|
||||||
|
desc: "Stateful: TCP new connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_tcp_established",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "established",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
// Generate established TCP packets (ACK flag)
|
||||||
|
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
|
||||||
|
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
desc: "Stateful: TCP established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_tcp_post_handshake",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
state: "post_handshake",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
// Generate packets with PSH+ACK flags for data transfer
|
||||||
|
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
desc: "Stateful: TCP post-handshake data transfer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_udp_new",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
state: "new",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
desc: "Stateful: UDP new connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stateful_udp_established",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
state: "established",
|
||||||
|
setupFunc: func(m *Manager) {
|
||||||
|
m.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
Mask: net.CIDRMask(0, 32),
|
||||||
|
}
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
},
|
||||||
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
|
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||||
|
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
desc: "Stateful: UDP established connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup scenario
|
||||||
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
|
// Use IPs outside WG range for routed network simulation
|
||||||
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
|
outbound, inbound := sc.genPackets(srcIP, dstIP)
|
||||||
|
|
||||||
|
// For stateful cases and established connections
|
||||||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
|
manager.processOutgoingHooks(outbound)
|
||||||
|
|
||||||
|
// For TCP post-handshake, simulate full handshake
|
||||||
|
if sc.state == "post_handshake" {
|
||||||
|
// SYN
|
||||||
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
|
manager.processOutgoingHooks(syn)
|
||||||
|
// SYN-ACK
|
||||||
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
|
manager.dropFilter(synack, manager.incomingRules)
|
||||||
|
// ACK
|
||||||
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
|
manager.processOutgoingHooks(ack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.dropFilter(inbound, manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scenarios is the shared benchmark matrix: each entry toggles conntrack
// (stateful), return-traffic rules, and routed (non-WG) addressing, and sets
// the number of concurrent connections to simulate.
var scenarios = []struct {
	name      string
	stateful  bool // Whether conntrack is enabled
	rules     bool // Whether to add return traffic rules
	routed    bool // Whether to test routed network traffic
	connCount int  // Number of concurrent connections
	desc      string
}{
	{
		name:      "stateless_with_rules_100conns",
		stateful:  false,
		rules:     true,
		routed:    false,
		connCount: 100,
		desc:      "Pure stateless with return traffic rules, 100 conns",
	},
	{
		name:      "stateless_with_rules_1000conns",
		stateful:  false,
		rules:     true,
		routed:    false,
		connCount: 1000,
		desc:      "Pure stateless with return traffic rules, 1000 conns",
	},
	{
		name:      "stateful_no_rules_100conns",
		stateful:  true,
		rules:     false,
		routed:    false,
		connCount: 100,
		desc:      "Pure stateful tracking without rules, 100 conns",
	},
	{
		name:      "stateful_no_rules_1000conns",
		stateful:  true,
		rules:     false,
		routed:    false,
		connCount: 1000,
		desc:      "Pure stateful tracking without rules, 1000 conns",
	},
	{
		name:      "stateful_with_rules_100conns",
		stateful:  true,
		rules:     true,
		routed:    false,
		connCount: 100,
		desc:      "Combined stateful + rules (current implementation), 100 conns",
	},
	{
		name:      "stateful_with_rules_1000conns",
		stateful:  true,
		rules:     true,
		routed:    false,
		connCount: 1000,
		desc:      "Combined stateful + rules (current implementation), 1000 conns",
	},
	{
		name:      "routed_network_100conns",
		stateful:  true,
		rules:     false,
		routed:    true,
		connCount: 100,
		desc:      "Routed network traffic (non-WG), 100 conns",
	},
	{
		name:      "routed_network_1000conns",
		stateful:  true,
		rules:     false,
		routed:    true,
		connCount: 1000,
		desc:      "Routed network traffic (non-WG), 1000 conns",
	},
}
|
||||||
|
|
||||||
|
// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns
|
||||||
|
func BenchmarkLongLivedConnections(b *testing.B) {
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup initial state based on scenario
|
||||||
|
if sc.rules {
|
||||||
|
// Single rule to allow all return traffic from port 80
|
||||||
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []uint16{80}},
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept, "", "return traffic")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate IPs for connections
|
||||||
|
srcIPs := make([]net.IP, sc.connCount)
|
||||||
|
dstIPs := make([]net.IP, sc.connCount)
|
||||||
|
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
if sc.routed {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||||
|
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||||
|
} else {
|
||||||
|
srcIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
dstIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create established connections
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
// Initial SYN
|
||||||
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
|
manager.processOutgoingHooks(syn)
|
||||||
|
|
||||||
|
// SYN-ACK
|
||||||
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
|
manager.dropFilter(synack, manager.incomingRules)
|
||||||
|
|
||||||
|
// ACK
|
||||||
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
|
manager.processOutgoingHooks(ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare test packets simulating bidirectional traffic
|
||||||
|
inPackets := make([][]byte, sc.connCount)
|
||||||
|
outPackets := make([][]byte, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
// Server -> Client (inbound)
|
||||||
|
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
// Client -> Server (outbound)
|
||||||
|
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
connIdx := i % sc.connCount
|
||||||
|
|
||||||
|
// Simulate bidirectional traffic
|
||||||
|
// First outbound data
|
||||||
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
|
// Then inbound response - this is what we're actually measuring
|
||||||
|
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkShortLivedConnections tests performance with many short-lived connections
|
||||||
|
func BenchmarkShortLivedConnections(b *testing.B) {
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup initial state based on scenario
|
||||||
|
if sc.rules {
|
||||||
|
// Single rule to allow all return traffic from port 80
|
||||||
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []uint16{80}},
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept, "", "return traffic")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate IPs for connections
|
||||||
|
srcIPs := make([]net.IP, sc.connCount)
|
||||||
|
dstIPs := make([]net.IP, sc.connCount)
|
||||||
|
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
if sc.routed {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||||
|
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||||
|
} else {
|
||||||
|
srcIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
dstIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create packet patterns for a complete HTTP-like short connection:
|
||||||
|
// 1. Initial handshake (SYN, SYN-ACK, ACK)
|
||||||
|
// 2. HTTP Request (PSH+ACK from client)
|
||||||
|
// 3. HTTP Response (PSH+ACK from server)
|
||||||
|
// 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK)
|
||||||
|
type connPackets struct {
|
||||||
|
syn []byte
|
||||||
|
synAck []byte
|
||||||
|
ack []byte
|
||||||
|
request []byte
|
||||||
|
response []byte
|
||||||
|
finClient []byte
|
||||||
|
ackServer []byte
|
||||||
|
finServer []byte
|
||||||
|
ackClient []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate all possible connection patterns
|
||||||
|
patterns := make([]connPackets, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
patterns[i] = connPackets{
|
||||||
|
// Handshake
|
||||||
|
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
|
||||||
|
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
|
||||||
|
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||||
|
|
||||||
|
// Data transfer
|
||||||
|
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
|
||||||
|
// Connection teardown
|
||||||
|
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||||
|
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPAck)),
|
||||||
|
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||||
|
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Each iteration creates a new short-lived connection
|
||||||
|
connIdx := i % sc.connCount
|
||||||
|
p := patterns[connIdx]
|
||||||
|
|
||||||
|
// Connection establishment
|
||||||
|
manager.processOutgoingHooks(p.syn)
|
||||||
|
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||||
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
|
// Data transfer
|
||||||
|
manager.processOutgoingHooks(p.request)
|
||||||
|
manager.dropFilter(p.response, manager.incomingRules)
|
||||||
|
|
||||||
|
// Connection teardown
|
||||||
|
manager.processOutgoingHooks(p.finClient)
|
||||||
|
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||||
|
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||||
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel
|
||||||
|
func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup initial state based on scenario
|
||||||
|
if sc.rules {
|
||||||
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []uint16{80}},
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept, "", "return traffic")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate IPs for connections
|
||||||
|
srcIPs := make([]net.IP, sc.connCount)
|
||||||
|
dstIPs := make([]net.IP, sc.connCount)
|
||||||
|
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
if sc.routed {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||||
|
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||||
|
} else {
|
||||||
|
srcIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
dstIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create established connections
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
|
manager.processOutgoingHooks(syn)
|
||||||
|
|
||||||
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
|
manager.dropFilter(synack, manager.incomingRules)
|
||||||
|
|
||||||
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
|
manager.processOutgoingHooks(ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-generate test packets
|
||||||
|
inPackets := make([][]byte, sc.connCount)
|
||||||
|
outPackets := make([][]byte, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
// Each goroutine gets its own counter to distribute load
|
||||||
|
counter := 0
|
||||||
|
for pb.Next() {
|
||||||
|
connIdx := counter % sc.connCount
|
||||||
|
counter++
|
||||||
|
|
||||||
|
// Simulate bidirectional traffic
|
||||||
|
manager.processOutgoingHooks(outPackets[connIdx])
|
||||||
|
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel
|
||||||
|
func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
// Configure stateful/stateless mode
|
||||||
|
if !sc.stateful {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
} else {
|
||||||
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, _ := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
defer b.Cleanup(func() {
|
||||||
|
require.NoError(b, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
})
|
||||||
|
|
||||||
|
if sc.rules {
|
||||||
|
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||||
|
&fw.Port{Values: []uint16{80}},
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept, "", "return traffic")
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate IPs and pre-generate all packet patterns
|
||||||
|
srcIPs := make([]net.IP, sc.connCount)
|
||||||
|
dstIPs := make([]net.IP, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
if sc.routed {
|
||||||
|
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||||
|
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||||
|
} else {
|
||||||
|
srcIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
dstIPs[i] = generateRandomIPs(1)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type connPackets struct {
|
||||||
|
syn []byte
|
||||||
|
synAck []byte
|
||||||
|
ack []byte
|
||||||
|
request []byte
|
||||||
|
response []byte
|
||||||
|
finClient []byte
|
||||||
|
ackServer []byte
|
||||||
|
finServer []byte
|
||||||
|
ackClient []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
patterns := make([]connPackets, sc.connCount)
|
||||||
|
for i := 0; i < sc.connCount; i++ {
|
||||||
|
patterns[i] = connPackets{
|
||||||
|
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
|
||||||
|
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
|
||||||
|
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||||
|
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||||
|
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||||
|
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPAck)),
|
||||||
|
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
|
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||||
|
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
|
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
counter := 0
|
||||||
|
for pb.Next() {
|
||||||
|
connIdx := counter % sc.connCount
|
||||||
|
counter++
|
||||||
|
p := patterns[connIdx]
|
||||||
|
|
||||||
|
// Full connection lifecycle
|
||||||
|
manager.processOutgoingHooks(p.syn)
|
||||||
|
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||||
|
manager.processOutgoingHooks(p.ack)
|
||||||
|
|
||||||
|
manager.processOutgoingHooks(p.request)
|
||||||
|
manager.dropFilter(p.response, manager.incomingRules)
|
||||||
|
|
||||||
|
manager.processOutgoingHooks(p.finClient)
|
||||||
|
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||||
|
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||||
|
manager.processOutgoingHooks(p.ackClient)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTCPPacketWithFlags creates a TCP packet with specific flags
|
||||||
|
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set TCP flags
|
||||||
|
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
|
||||||
|
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
|
||||||
|
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
|
||||||
|
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
|
||||||
|
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
|
||||||
|
|
||||||
|
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@@ -67,12 +69,11 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -102,38 +103,16 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule 2"
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip = net.ParseIP("192.168.1.1")
|
|
||||||
proto = fw.ProtocolTCP
|
|
||||||
port = &fw.Port{Values: []int{80}}
|
|
||||||
direction = fw.RuleDirectionIN
|
|
||||||
action = fw.ActionDrop
|
|
||||||
comment = "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule {
|
|
||||||
err = m.DeletePeerRule(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
@@ -185,10 +164,10 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager := &Manager{
|
manager, err := Create(&IFaceMock{
|
||||||
incomingRules: map[string]RuleSet{},
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
outgoingRules: map[string]RuleSet{},
|
})
|
||||||
}
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
@@ -215,18 +194,14 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tt.dPort != addedRule.dPort {
|
if tt.dPort != addedRule.dPort.Values[0] {
|
||||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
|
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tt.expDir != addedRule.direction {
|
|
||||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if addedRule.udpHook == nil {
|
if addedRule.udpHook == nil {
|
||||||
t.Errorf("expected udpHook to be set")
|
t.Errorf("expected udpHook to be set")
|
||||||
return
|
return
|
||||||
@@ -248,12 +223,11 @@ func TestManagerReset(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -287,11 +261,10 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -313,7 +286,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
t.Errorf("failed to set network layer for checksum: %v", err)
|
t.Errorf("failed to set network layer for checksum: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload := gopacket.Payload([]byte("test"))
|
payload := gopacket.Payload("test")
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
buf := gopacket.NewSerializeBuffer()
|
||||||
opts := gopacket.SerializeOptions{
|
opts := gopacket.SerializeOptions{
|
||||||
@@ -325,7 +298,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
|
if m.dropFilter(buf.Bytes(), m.incomingRules) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -348,6 +321,9 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Reset(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
hookFunc := func(data []byte) bool { return true }
|
hookFunc := func(data []byte) bool { return true }
|
||||||
@@ -384,6 +360,88 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestProcessOutgoingHooks verifies that a registered UDP packet hook fires
// for a matching outbound UDP packet and that non-UDP packets are ignored.
func TestProcessOutgoingHooks(t *testing.T) {
	manager, err := Create(&IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
	})
	require.NoError(t, err)

	manager.wgNetwork = &net.IPNet{
		IP:   net.ParseIP("100.10.0.0"),
		Mask: net.CIDRMask(16, 32),
	}
	// Replace the default UDP tracker with a short-timeout one so the test
	// does not depend on the production timeout.
	manager.udpTracker.Close()
	manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
	defer func() {
		require.NoError(t, manager.Reset(nil))
	}()

	// Rebuild the decoder pool starting at the IPv4 layer (no Ethernet
	// framing on the test packets).
	manager.decoders = sync.Pool{
		New: func() any {
			d := &decoder{
				decoded: []gopacket.LayerType{},
			}
			d.parser = gopacket.NewDecodingLayerParser(
				layers.LayerTypeIPv4,
				&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
			)
			d.parser.IgnoreUnsupported = true
			return d
		},
	}

	// Hook on UDP destination 100.10.0.100:53; the callback records that it ran.
	hookCalled := false
	hookID := manager.AddUDPPacketHook(
		false,
		net.ParseIP("100.10.0.100"),
		53,
		func([]byte) bool {
			hookCalled = true
			return true
		},
	)
	require.NotEmpty(t, hookID)

	// Create test UDP packet
	ipv4 := &layers.IPv4{
		TTL:      64,
		Version:  4,
		SrcIP:    net.ParseIP("100.10.0.1"),
		DstIP:    net.ParseIP("100.10.0.100"),
		Protocol: layers.IPProtocolUDP,
	}
	udp := &layers.UDP{
		SrcPort: 51334,
		DstPort: 53,
	}

	err = udp.SetNetworkLayerForChecksum(ipv4)
	require.NoError(t, err)
	payload := gopacket.Payload("test")

	buf := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths:       true,
	}
	err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload)
	require.NoError(t, err)

	// Test hook gets called
	result := manager.processOutgoingHooks(buf.Bytes())
	require.True(t, result)
	require.True(t, hookCalled)

	// Test non-UDP packet is ignored
	ipv4.Protocol = layers.IPProtocolTCP
	buf = gopacket.NewSerializeBuffer()
	err = gopacket.SerializeLayers(buf, opts, ipv4)
	require.NoError(t, err)

	result = manager.processOutgoingHooks(buf.Bytes())
	require.False(t, result)
}
|
||||||
|
|
||||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
@@ -405,12 +463,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@@ -418,3 +472,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
|
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
||||||
|
manager.decoders = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
return d
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Reset(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set up packet parameters
|
||||||
|
srcIP := net.ParseIP("100.10.0.1")
|
||||||
|
dstIP := net.ParseIP("100.10.0.100")
|
||||||
|
srcPort := uint16(51334)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
// Create outbound packet
|
||||||
|
outboundIPv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
outboundUDP := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
|
DstPort: layers.UDPPort(dstPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outboundBuf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = gopacket.SerializeLayers(outboundBuf, opts,
|
||||||
|
outboundIPv4,
|
||||||
|
outboundUDP,
|
||||||
|
gopacket.Payload("test"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Process outbound packet and verify connection tracking
|
||||||
|
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
||||||
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
|
// Verify connection was tracked
|
||||||
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
|
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
||||||
|
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
||||||
|
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||||
|
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||||
|
|
||||||
|
// Create valid inbound response packet
|
||||||
|
inboundIPv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: dstIP, // Original destination is now source
|
||||||
|
DstIP: srcIP, // Original source is now destination
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
inboundUDP := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(dstPort), // Original destination port is now source
|
||||||
|
DstPort: layers.UDPPort(srcPort), // Original source port is now destination
|
||||||
|
}
|
||||||
|
|
||||||
|
err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
inboundBuf := gopacket.NewSerializeBuffer()
|
||||||
|
err = gopacket.SerializeLayers(inboundBuf, opts,
|
||||||
|
inboundIPv4,
|
||||||
|
inboundUDP,
|
||||||
|
gopacket.Payload("response"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Test roundtrip response handling over time
|
||||||
|
checkPoints := []struct {
|
||||||
|
sleep time.Duration
|
||||||
|
shouldAllow bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sleep: 0,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Immediate response should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sleep: 50 * time.Millisecond,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Response within timeout should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sleep: 100 * time.Millisecond,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Response at half timeout should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// tracker hasn't updated conn for 250ms -> greater than 200ms timeout
|
||||||
|
sleep: 250 * time.Millisecond,
|
||||||
|
shouldAllow: false,
|
||||||
|
description: "Response after timeout should be dropped",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cp := range checkPoints {
|
||||||
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
|
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
||||||
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
|
// If the connection should still be valid, verify it exists
|
||||||
|
if cp.shouldAllow {
|
||||||
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
require.True(t, exists, "Connection should still exist during valid window")
|
||||||
|
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
||||||
|
"LastSeen should be updated for valid responses")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid response packets (while connection is expired)
|
||||||
|
invalidCases := []struct {
|
||||||
|
name string
|
||||||
|
modifyFunc func(*layers.IPv4, *layers.UDP)
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "wrong source IP",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
ip.SrcIP = net.ParseIP("100.10.0.101")
|
||||||
|
},
|
||||||
|
description: "Response from wrong IP should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong destination IP",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
ip.DstIP = net.ParseIP("100.10.0.2")
|
||||||
|
},
|
||||||
|
description: "Response to wrong IP should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong source port",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
udp.SrcPort = 54
|
||||||
|
},
|
||||||
|
description: "Response from wrong port should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong destination port",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
udp.DstPort = 51335
|
||||||
|
},
|
||||||
|
description: "Response to wrong port should be dropped",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new outbound connection for invalid tests
|
||||||
|
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
||||||
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
|
for _, tc := range invalidCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testIPv4 := *inboundIPv4
|
||||||
|
testUDP := *inboundUDP
|
||||||
|
|
||||||
|
tc.modifyFunc(&testIPv4, &testUDP)
|
||||||
|
|
||||||
|
err = testUDP.SetNetworkLayerForChecksum(&testIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testBuf := gopacket.NewSerializeBuffer()
|
||||||
|
err = gopacket.SerializeLayers(testBuf, opts,
|
||||||
|
&testIPv4,
|
||||||
|
&testUDP,
|
||||||
|
gopacket.Payload("response"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the invalid packet is dropped
|
||||||
|
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
||||||
|
require.True(t, drop, tc.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,8 +33,6 @@ type TunKernelDevice struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||||
checkUser()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &TunKernelDevice{
|
return &TunKernelDevice{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -32,8 +30,6 @@ type USPDevice struct {
|
|||||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||||
log.Infof("using userspace bind mode")
|
log.Infof("using userspace bind mode")
|
||||||
|
|
||||||
checkUser()
|
|
||||||
|
|
||||||
return &USPDevice{
|
return &USPDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -134,12 +130,3 @@ func (t *USPDevice) assignAddr() error {
|
|||||||
|
|
||||||
return link.assignAddr(t.address)
|
return link.assignAddr(t.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkUser() {
|
|
||||||
if runtime.GOOS == "freebsd" {
|
|
||||||
euid := os.Geteuid()
|
|
||||||
if euid != 0 {
|
|
||||||
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -203,6 +203,11 @@ func (l *Link) setAddr(ip, netmask string) error {
|
|||||||
return fmt.Errorf("set interface addr: %w", err)
|
return fmt.Errorf("set interface addr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command("ifconfig", l.name, "inet6", "fe80::/64")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,10 @@ func IsEnabled() bool {
|
|||||||
|
|
||||||
func ListenAddr() string {
|
func ListenAddr() string {
|
||||||
sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT")
|
sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT")
|
||||||
|
if sPort == "" {
|
||||||
|
return listenAddr(DefaultSocks5Port)
|
||||||
|
}
|
||||||
|
|
||||||
port, err := strconv.Atoi(sPort)
|
port, err := strconv.Atoi(sPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port)
|
log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port)
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
d.rollBack(newRulePairs)
|
d.rollBack(newRulePairs)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if len(rules) > 0 {
|
if len(rulePair) > 0 {
|
||||||
d.peerRulesPairs[pairID] = rulePair
|
d.peerRulesPairs[pairID] = rulePair
|
||||||
newRulePairs[pairID] = rulePair
|
newRulePairs[pairID] = rulePair
|
||||||
}
|
}
|
||||||
@@ -268,13 +268,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var port *firewall.Port
|
var port *firewall.Port
|
||||||
if r.Port != "" {
|
if r.PortInfo != nil {
|
||||||
|
port = convertPortInfo(r.PortInfo)
|
||||||
|
} else if r.Port != "" {
|
||||||
|
// old version of management, single port
|
||||||
value, err := strconv.Atoi(r.Port)
|
value, err := strconv.Atoi(r.Port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid port: %w", err)
|
||||||
}
|
}
|
||||||
port = &firewall.Port{
|
port = &firewall.Port{
|
||||||
Values: []int{value},
|
Values: []uint16{uint16(value)},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -288,6 +291,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
|
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||||
|
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
@@ -308,25 +313,12 @@ func (d *DefaultManager) addInRules(
|
|||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
||||||
rule, err := d.firewall.AddPeerFiltering(
|
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
|
||||||
rules = append(rules, rule...)
|
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
|
||||||
return rules, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.firewall.AddPeerFiltering(
|
return rule, nil
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, rule...), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
@@ -337,25 +329,16 @@ func (d *DefaultManager) addOutRules(
|
|||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(
|
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
|
||||||
}
|
|
||||||
rules = append(rules, rule...)
|
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
return rules, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.firewall.AddPeerFiltering(
|
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(rules, rule...), nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||||
@@ -559,14 +542,14 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
|||||||
|
|
||||||
if portInfo.GetPort() != 0 {
|
if portInfo.GetPort() != 0 {
|
||||||
return &firewall.Port{
|
return &firewall.Port{
|
||||||
Values: []int{int(portInfo.GetPort())},
|
Values: []uint16{uint16(int(portInfo.GetPort()))},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if portInfo.GetRange() != nil {
|
if portInfo.GetRange() != nil {
|
||||||
return &firewall.Port{
|
return &firewall.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
Values: []uint16{uint16(portInfo.GetRange().Start), uint16(portInfo.GetRange().End)},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -119,8 +119,8 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
networkMap.FirewallRulesIsEmpty = false
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if len(acl.peerRulesPairs) != 1 {
|
||||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -356,8 +356,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 4 {
|
if len(acl.peerRulesPairs) != 3 {
|
||||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,13 @@ type ConfigInput struct {
|
|||||||
DNSRouteInterval *time.Duration
|
DNSRouteInterval *time.Duration
|
||||||
ClientCertPath string
|
ClientCertPath string
|
||||||
ClientCertKeyPath string
|
ClientCertKeyPath string
|
||||||
|
|
||||||
|
DisableClientRoutes *bool
|
||||||
|
DisableServerRoutes *bool
|
||||||
|
DisableDNS *bool
|
||||||
|
DisableFirewall *bool
|
||||||
|
|
||||||
|
BlockLANAccess *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@@ -78,6 +85,14 @@ type Config struct {
|
|||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
|
||||||
|
DisableClientRoutes bool
|
||||||
|
DisableServerRoutes bool
|
||||||
|
DisableDNS bool
|
||||||
|
DisableFirewall bool
|
||||||
|
|
||||||
|
BlockLANAccess bool
|
||||||
|
|
||||||
// SSHKey is a private SSH key in a PEM format
|
// SSHKey is a private SSH key in a PEM format
|
||||||
SSHKey string
|
SSHKey string
|
||||||
|
|
||||||
@@ -402,7 +417,56 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
config.DNSRouteInterval = dynamic.DefaultInterval
|
config.DNSRouteInterval = dynamic.DefaultInterval
|
||||||
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
|
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
|
||||||
updated = true
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes {
|
||||||
|
if *input.DisableClientRoutes {
|
||||||
|
log.Infof("disabling client routes")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling client routes")
|
||||||
|
}
|
||||||
|
config.DisableClientRoutes = *input.DisableClientRoutes
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes {
|
||||||
|
if *input.DisableServerRoutes {
|
||||||
|
log.Infof("disabling server routes")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling server routes")
|
||||||
|
}
|
||||||
|
config.DisableServerRoutes = *input.DisableServerRoutes
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS {
|
||||||
|
if *input.DisableDNS {
|
||||||
|
log.Infof("disabling DNS configuration")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling DNS configuration")
|
||||||
|
}
|
||||||
|
config.DisableDNS = *input.DisableDNS
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall {
|
||||||
|
if *input.DisableFirewall {
|
||||||
|
log.Infof("disabling firewall configuration")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling firewall configuration")
|
||||||
|
}
|
||||||
|
config.DisableFirewall = *input.DisableFirewall
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.BlockLANAccess != nil && *input.BlockLANAccess != config.BlockLANAccess {
|
||||||
|
if *input.BlockLANAccess {
|
||||||
|
log.Infof("blocking LAN access")
|
||||||
|
} else {
|
||||||
|
log.Infof("allowing LAN access")
|
||||||
|
}
|
||||||
|
config.BlockLANAccess = *input.BlockLANAccess
|
||||||
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.ClientCertKeyPath != "" {
|
if input.ClientCertKeyPath != "" {
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
||||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
|
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
@@ -382,8 +382,7 @@ func (c *ConnectClient) isContextCancelled() bool {
|
|||||||
// SetNetworkMapPersistence enables or disables network map persistence.
|
// SetNetworkMapPersistence enables or disables network map persistence.
|
||||||
// When enabled, the last received network map will be stored and can be retrieved
|
// When enabled, the last received network map will be stored and can be retrieved
|
||||||
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
||||||
// network map will be cleared. This functionality is primarily used for debugging
|
// network map will be cleared.
|
||||||
// and should not be enabled during normal operation.
|
|
||||||
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
|
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.persistNetworkMap = enabled
|
c.persistNetworkMap = enabled
|
||||||
@@ -416,6 +415,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
RosenpassPermissive: config.RosenpassPermissive,
|
RosenpassPermissive: config.RosenpassPermissive,
|
||||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||||
DNSRouteInterval: config.DNSRouteInterval,
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
|
DisableServerRoutes: config.DisableServerRoutes,
|
||||||
|
DisableDNS: config.DisableDNS,
|
||||||
|
DisableFirewall: config.DisableFirewall,
|
||||||
|
|
||||||
|
BlockLANAccess: config.BlockLANAccess,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
@@ -457,7 +463,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -465,6 +471,15 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
|||||||
}
|
}
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
|
sysInfo.SetFlags(
|
||||||
|
config.RosenpassEnabled,
|
||||||
|
config.RosenpassPermissive,
|
||||||
|
config.ServerSSHAllowed,
|
||||||
|
config.DisableClientRoutes,
|
||||||
|
config.DisableServerRoutes,
|
||||||
|
config.DisableDNS,
|
||||||
|
config.DisableFirewall,
|
||||||
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
18
client/internal/dns/consts.go
Normal file
18
client/internal/dns/consts.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
var fileUncleanShutdownResolvConfLocation string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
fileUncleanShutdownResolvConfLocation = os.Getenv("NB_UNCLEAN_SHUTDOWN_RESOLV_FILE")
|
||||||
|
if fileUncleanShutdownResolvConfLocation == "" {
|
||||||
|
fileUncleanShutdownResolvConfLocation = filepath.Join(configs.StateDir, "resolv.conf")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
const (
|
|
||||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
|
||||||
)
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
const (
|
|
||||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
|
||||||
)
|
|
||||||
238
client/internal/dns/handler_chain.go
Normal file
238
client/internal/dns/handler_chain.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PriorityDNSRoute = 100
|
||||||
|
PriorityMatchDomain = 50
|
||||||
|
PriorityDefault = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type SubdomainMatcher interface {
|
||||||
|
dns.Handler
|
||||||
|
MatchSubdomains() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type HandlerEntry struct {
|
||||||
|
Handler dns.Handler
|
||||||
|
Priority int
|
||||||
|
Pattern string
|
||||||
|
OrigPattern string
|
||||||
|
IsWildcard bool
|
||||||
|
StopHandler handlerWithStop
|
||||||
|
MatchSubdomains bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlerChain represents a prioritized chain of DNS handlers
|
||||||
|
type HandlerChain struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
handlers []HandlerEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||||
|
type ResponseWriterChain struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
origPattern string
|
||||||
|
shouldContinue bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||||
|
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
|
||||||
|
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
|
||||||
|
w.shouldContinue = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.ResponseWriter.WriteMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandlerChain() *HandlerChain {
|
||||||
|
return &HandlerChain{
|
||||||
|
handlers: make([]HandlerEntry, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrigPattern returns the original pattern of the handler that wrote the response
|
||||||
|
func (w *ResponseWriterChain) GetOrigPattern() string {
|
||||||
|
return w.origPattern
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||||
|
origPattern := pattern
|
||||||
|
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||||
|
if isWildcard {
|
||||||
|
pattern = pattern[2:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
|
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||||
|
if c.handlers[i].StopHandler != nil {
|
||||||
|
c.handlers[i].StopHandler.stop()
|
||||||
|
}
|
||||||
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if handler implements SubdomainMatcher interface
|
||||||
|
matchSubdomains := false
|
||||||
|
if matcher, ok := handler.(SubdomainMatcher); ok {
|
||||||
|
matchSubdomains = matcher.MatchSubdomains()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
|
pattern, origPattern, isWildcard, matchSubdomains, priority)
|
||||||
|
|
||||||
|
entry := HandlerEntry{
|
||||||
|
Handler: handler,
|
||||||
|
Priority: priority,
|
||||||
|
Pattern: pattern,
|
||||||
|
OrigPattern: origPattern,
|
||||||
|
IsWildcard: isWildcard,
|
||||||
|
StopHandler: stopHandler,
|
||||||
|
MatchSubdomains: matchSubdomains,
|
||||||
|
}
|
||||||
|
|
||||||
|
pos := c.findHandlerPosition(entry)
|
||||||
|
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||||
|
func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
||||||
|
for i, h := range c.handlers {
|
||||||
|
// prio first
|
||||||
|
if h.Priority < newEntry.Priority {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// domain specificity next
|
||||||
|
if h.Priority == newEntry.Priority {
|
||||||
|
newDots := strings.Count(newEntry.Pattern, ".")
|
||||||
|
existingDots := strings.Count(h.Pattern, ".")
|
||||||
|
if newDots > existingDots {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add at end
|
||||||
|
return len(c.handlers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveHandler removes a handler for the given pattern and priority
|
||||||
|
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
pattern = dns.Fqdn(pattern)
|
||||||
|
|
||||||
|
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||||
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
|
entry := c.handlers[i]
|
||||||
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
|
if entry.StopHandler != nil {
|
||||||
|
entry.StopHandler.stop()
|
||||||
|
}
|
||||||
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
||||||
|
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||||
|
for _, entry := range c.handlers {
|
||||||
|
if strings.EqualFold(entry.Pattern, pattern) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
qname := strings.ToLower(r.Question[0].Name)
|
||||||
|
log.Tracef("handling DNS request for domain=%s", qname)
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
handlers := slices.Clone(c.handlers)
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
if log.IsLevelEnabled(log.TraceLevel) {
|
||||||
|
log.Tracef("current handlers (%d):", len(handlers))
|
||||||
|
for _, h := range handlers {
|
||||||
|
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
||||||
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try handlers in priority order
|
||||||
|
for _, entry := range handlers {
|
||||||
|
var matched bool
|
||||||
|
switch {
|
||||||
|
case entry.Pattern == ".":
|
||||||
|
matched = true
|
||||||
|
case entry.IsWildcard:
|
||||||
|
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||||
|
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||||
|
default:
|
||||||
|
// For non-wildcard patterns:
|
||||||
|
// If handler wants subdomain matching, allow suffix match
|
||||||
|
// Otherwise require exact match
|
||||||
|
if entry.MatchSubdomains {
|
||||||
|
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
|
} else {
|
||||||
|
matched = strings.EqualFold(qname, entry.Pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matched {
|
||||||
|
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
||||||
|
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
||||||
|
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
||||||
|
|
||||||
|
chainWriter := &ResponseWriterChain{
|
||||||
|
ResponseWriter: w,
|
||||||
|
origPattern: entry.OrigPattern,
|
||||||
|
}
|
||||||
|
entry.Handler.ServeDNS(chainWriter, r)
|
||||||
|
|
||||||
|
// If handler wants to continue, try next handler
|
||||||
|
if chainWriter.shouldContinue {
|
||||||
|
log.Tracef("handler requested continue to next handler")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No handler matched or all handlers passed
|
||||||
|
log.Tracef("no handler found for domain=%s", qname)
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
832
client/internal/dns/handler_chain_test.go
Normal file
832
client/internal/dns/handler_chain_test.go
Normal file
@@ -0,0 +1,832 @@
|
|||||||
|
package dns_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in
// priority order: with three handlers on the same pattern, only the
// PriorityDNSRoute (highest) handler must receive the query.
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	// Create mock handlers for different priorities
	defaultHandler := &nbdns.MockHandler{}
	matchDomainHandler := &nbdns.MockHandler{}
	dnsRouteHandler := &nbdns.MockHandler{}

	// Setup handlers with different priorities
	chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
	chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
	chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)

	// Create test request
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	// Create test writer
	w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

	// Setup expectations - only highest priority handler should be called
	dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
	matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
	defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()

	// Execute
	chain.ServeDNS(w, r)

	// Verify all expectations were met
	dnsRouteHandler.AssertExpectations(t)
	matchDomainHandler.AssertExpectations(t)
	defaultHandler.AssertExpectations(t)
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching
// scenarios: exact match, subdomain matching (opt-in via MatchSubdomains),
// wildcard patterns (including the non-matching apex case), the root zone
// catch-all, and a plain mismatch.
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
	tests := []struct {
		name            string
		handlerDomain   string
		queryDomain     string
		isWildcard      bool
		matchSubdomains bool
		shouldMatch     bool
	}{
		{
			name:            "exact match",
			handlerDomain:   "example.com.",
			queryDomain:     "example.com.",
			isWildcard:      false,
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "subdomain with non-wildcard and MatchSubdomains true",
			handlerDomain:   "example.com.",
			queryDomain:     "sub.example.com.",
			isWildcard:      false,
			matchSubdomains: true,
			shouldMatch:     true,
		},
		{
			name:            "subdomain with non-wildcard and MatchSubdomains false",
			handlerDomain:   "example.com.",
			queryDomain:     "sub.example.com.",
			isWildcard:      false,
			matchSubdomains: false,
			shouldMatch:     false,
		},
		{
			name:            "wildcard match",
			handlerDomain:   "*.example.com.",
			queryDomain:     "sub.example.com.",
			isWildcard:      true,
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "wildcard no match on apex",
			handlerDomain:   "*.example.com.",
			queryDomain:     "example.com.",
			isWildcard:      true,
			matchSubdomains: false,
			shouldMatch:     false,
		},
		{
			name:            "root zone match",
			handlerDomain:   ".",
			queryDomain:     "anything.com.",
			isWildcard:      false,
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "no match different domain",
			handlerDomain:   "example.com.",
			queryDomain:     "example.org.",
			isWildcard:      false,
			matchSubdomains: false,
			shouldMatch:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			var handler dns.Handler

			// Pick the mock flavor: MockSubdomainHandler opts into
			// subdomain matching via the SubdomainMatcher interface.
			if tt.matchSubdomains {
				mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
				handler = mockSubHandler
				if tt.shouldMatch {
					mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
				}
			} else {
				mockHandler := &nbdns.MockHandler{}
				handler = mockHandler
				if tt.shouldMatch {
					mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
				}
			}

			pattern := tt.handlerDomain
			if tt.isWildcard {
				pattern = "*." + tt.handlerDomain[2:]
			}

			chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)

			r := new(dns.Msg)
			r.SetQuestion(tt.queryDomain, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

			chain.ServeDNS(w, r)

			// A mock with no expectations set must not have been called.
			if h, ok := handler.(*nbdns.MockHandler); ok {
				h.AssertExpectations(t)
			} else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok {
				h.AssertExpectations(t)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with
// overlapping domain patterns: when several registered patterns could match
// a query, exactly one handler — the highest priority, most specific one —
// must be invoked.
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
	tests := []struct {
		name     string
		handlers []struct {
			pattern  string
			priority int
		}
		queryDomain     string
		expectedCalls   int
		expectedHandler int // index of the handler that should be called
	}{
		{
			name: "wildcard and exact same priority - exact should win",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "example.com.", priority: nbdns.PriorityDefault},
			},
			queryDomain:     "example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // exact match handler should be called
		},
		{
			name: "higher priority wildcard over lower priority exact",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: "example.com.", priority: nbdns.PriorityDefault},
				{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "test.example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // higher priority wildcard handler should be called
		},
		{
			name: "multiple wildcards different priorities",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
				{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "test.example.com.",
			expectedCalls:   1,
			expectedHandler: 2, // highest priority handler should be called
		},
		{
			name: "subdomain with mix of patterns",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
				{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "sub.test.example.com.",
			expectedCalls:   1,
			expectedHandler: 2, // highest priority matching handler should be called
		},
		{
			name: "root zone with specific domain",
			handlers: []struct {
				pattern  string
				priority int
			}{
				{pattern: ".", priority: nbdns.PriorityDefault},
				{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // higher priority specific domain should win over root
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			var handlers []*nbdns.MockHandler

			// Setup handlers and expectations
			for i := range tt.handlers {
				handler := &nbdns.MockHandler{}
				handlers = append(handlers, handler)

				// Set expectation based on whether this handler should be called
				if i == tt.expectedHandler {
					handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
				} else {
					handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
				}

				chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
			}

			// Create and execute request
			r := new(dns.Msg)
			r.SetQuestion(tt.queryDomain, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
			chain.ServeDNS(w, r)

			// Verify expectations
			for _, handler := range handlers {
				handler.AssertExpectations(t)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation
// functionality: a handler writes a response with the MsgHdr.Zero bit set
// to signal "pass the query to the next handler". Two handlers pass, the
// third answers; all three must be invoked in priority order.
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	// Create handlers
	handler1 := &nbdns.MockHandler{}
	handler2 := &nbdns.MockHandler{}
	handler3 := &nbdns.MockHandler{}

	// Add handlers in priority order
	chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
	chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
	chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)

	// Create test request
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	// Setup mock responses to simulate chain continuation
	handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
		// First handler signals continue
		w := args.Get(0).(*nbdns.ResponseWriterChain)
		resp := new(dns.Msg)
		resp.SetRcode(r, dns.RcodeNameError)
		resp.MsgHdr.Zero = true // Signal to continue
		assert.NoError(t, w.WriteMsg(resp))
	}).Once()

	handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
		// Second handler signals continue
		w := args.Get(0).(*nbdns.ResponseWriterChain)
		resp := new(dns.Msg)
		resp.SetRcode(r, dns.RcodeNameError)
		resp.MsgHdr.Zero = true
		assert.NoError(t, w.WriteMsg(resp))
	}).Once()

	handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
		// Last handler responds normally
		w := args.Get(0).(*nbdns.ResponseWriterChain)
		resp := new(dns.Msg)
		resp.SetRcode(r, dns.RcodeSuccess)
		assert.NoError(t, w.WriteMsg(resp))
	}).Once()

	// Execute
	w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	chain.ServeDNS(w, r)

	// Verify all handlers were called in order
	handler1.AssertExpectations(t)
	handler2.AssertExpectations(t)
	handler3.AssertExpectations(t)
}
|
||||||
|
|
||||||
|
// mockResponseWriter implements dns.ResponseWriter for testing.
// All methods are no-op stubs that return zero values; the embedded
// mock.Mock is available should a test need call recording.
type mockResponseWriter struct {
	mock.Mock
}

func (m *mockResponseWriter) LocalAddr() net.Addr       { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr      { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error   { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error              { return nil }
func (m *mockResponseWriter) TsigStatus() error         { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool)       {}
func (m *mockResponseWriter) Hijack()                   {}
|
||||||
|
|
||||||
|
// TestHandlerChain_PriorityDeregistration verifies that removing a handler
// at one priority leaves handlers at other priorities intact, and that
// queries then fall through to the highest remaining priority.
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
	tests := []struct {
		name string
		ops  []struct {
			action   string // "add" or "remove"
			pattern  string
			priority int
		}
		query         string
		expectedCalls map[int]bool // map[priority]shouldBeCalled
	}{
		{
			name: "remove high priority keeps lower priority handler",
			ops: []struct {
				action   string
				pattern  string
				priority int
			}{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"remove", "example.com.", nbdns.PriorityDNSRoute},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    false,
				nbdns.PriorityMatchDomain: true,
			},
		},
		{
			name: "remove lower priority keeps high priority handler",
			ops: []struct {
				action   string
				pattern  string
				priority int
			}{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"remove", "example.com.", nbdns.PriorityMatchDomain},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    true,
				nbdns.PriorityMatchDomain: false,
			},
		},
		{
			name: "remove all handlers in order",
			ops: []struct {
				action   string
				pattern  string
				priority int
			}{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"add", "example.com.", nbdns.PriorityDefault},
				{"remove", "example.com.", nbdns.PriorityDNSRoute},
				{"remove", "example.com.", nbdns.PriorityMatchDomain},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    false,
				nbdns.PriorityMatchDomain: false,
				nbdns.PriorityDefault:     true,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			handlers := make(map[int]*nbdns.MockHandler)

			// Execute operations
			for _, op := range tt.ops {
				if op.action == "add" {
					handler := &nbdns.MockHandler{}
					handlers[op.priority] = handler
					chain.AddHandler(op.pattern, handler, op.priority, nil)
				} else {
					chain.RemoveHandler(op.pattern, op.priority)
				}
			}

			// Create test request
			r := new(dns.Msg)
			r.SetQuestion(tt.query, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

			// Setup expectations
			for priority, handler := range handlers {
				if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall {
					handler.On("ServeDNS", mock.Anything, r).Once()
				} else {
					handler.On("ServeDNS", mock.Anything, r).Maybe()
				}
			}

			// Execute request
			chain.ServeDNS(w, r)

			// Verify expectations
			for _, handler := range handlers {
				handler.AssertExpectations(t)
			}

			// Verify handler exists check
			for priority, shouldExist := range tt.expectedCalls {
				if shouldExist {
					assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
						"Handler chain should have handlers for pattern after removing priority %d", priority)
				}
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestHandlerChain_MultiPriorityHandling walks the chain through staged
// removals: with three subdomain-matching handlers on one domain, each
// removal of the current highest priority must promote the next one, and
// removing the last handler must leave the chain empty for that domain.
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	testDomain := "example.com."
	testQuery := "test.example.com."

	// Create handlers with MatchSubdomains enabled
	routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
	matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
	defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}

	// Create test request that will be reused
	r := new(dns.Msg)
	r.SetQuestion(testQuery, dns.TypeA)

	// Add handlers in mixed order
	chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
	chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
	chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)

	// Test 1: Initial state with all three handlers
	w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	// Highest priority handler (routeHandler) should be called
	routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()

	chain.ServeDNS(w, r)
	routeHandler.AssertExpectations(t)

	// Test 2: Remove highest priority handler
	chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
	assert.True(t, chain.HasHandlers(testDomain))

	w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	// Now middle priority handler (matchHandler) should be called
	matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()

	chain.ServeDNS(w, r)
	matchHandler.AssertExpectations(t)

	// Test 3: Remove middle priority handler
	chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
	assert.True(t, chain.HasHandlers(testDomain))

	w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	// Now lowest priority handler (defaultHandler) should be called
	defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()

	chain.ServeDNS(w, r)
	defaultHandler.AssertExpectations(t)

	// Test 4: Remove last handler
	chain.RemoveHandler(testDomain, nbdns.PriorityDefault)

	assert.False(t, chain.HasHandlers(testDomain))
}
|
||||||
|
|
||||||
|
// TestHandlerChain_CaseSensitivity verifies that registration, replacement,
// and query matching are all case-insensitive: patterns registered in any
// case must match queries in any case, and re-adding a pattern that differs
// only by case (same priority) must replace the earlier handler.
func TestHandlerChain_CaseSensitivity(t *testing.T) {
	tests := []struct {
		name        string
		scenario    string
		addHandlers []struct {
			pattern     string
			priority    int
			subdomains  bool
			shouldMatch bool
		}
		query         string
		expectedCalls int
	}{
		{
			name:     "case insensitive exact match",
			scenario: "handler registered lowercase, query uppercase",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"example.com.", nbdns.PriorityDefault, false, true},
			},
			query:         "EXAMPLE.COM.",
			expectedCalls: 1,
		},
		{
			name:     "case insensitive wildcard match",
			scenario: "handler registered mixed case wildcard, query different case",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"*.Example.Com.", nbdns.PriorityDefault, false, true},
			},
			query:         "sub.EXAMPLE.COM.",
			expectedCalls: 1,
		},
		{
			name:     "multiple handlers different case same domain",
			scenario: "second handler should replace first despite case difference",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
				{"example.com.", nbdns.PriorityDefault, false, true},
			},
			query:         "ExAmPlE.cOm.",
			expectedCalls: 1,
		},
		{
			name:     "subdomain matching case insensitive",
			scenario: "handler with MatchSubdomains true should match regardless of case",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"example.com.", nbdns.PriorityDefault, true, true},
			},
			query:         "SUB.EXAMPLE.COM.",
			expectedCalls: 1,
		},
		{
			name:     "root zone case insensitive",
			scenario: "root zone handler should match regardless of case",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{".", nbdns.PriorityDefault, false, true},
			},
			query:         "EXAMPLE.COM.",
			expectedCalls: 1,
		},
		{
			name:     "multiple handlers different priority",
			scenario: "should call higher priority handler despite case differences",
			addHandlers: []struct {
				pattern     string
				priority    int
				subdomains  bool
				shouldMatch bool
			}{
				{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
				{"example.com.", nbdns.PriorityMatchDomain, false, false},
				{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
			},
			query:         "example.com.",
			expectedCalls: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			handlerCalls := make(map[string]bool) // track which patterns were called

			// Add handlers according to test case
			for _, h := range tt.addHandlers {
				var handler dns.Handler
				pattern := h.pattern // capture pattern for closure

				if h.subdomains {
					subHandler := &nbdns.MockSubdomainHandler{
						Subdomains: true,
					}
					if h.shouldMatch {
						subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
							handlerCalls[pattern] = true
							w := args.Get(0).(dns.ResponseWriter)
							r := args.Get(1).(*dns.Msg)
							resp := new(dns.Msg)
							resp.SetRcode(r, dns.RcodeSuccess)
							assert.NoError(t, w.WriteMsg(resp))
						}).Once()
					}
					handler = subHandler
				} else {
					mockHandler := &nbdns.MockHandler{}
					if h.shouldMatch {
						mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
							handlerCalls[pattern] = true
							w := args.Get(0).(dns.ResponseWriter)
							r := args.Get(1).(*dns.Msg)
							resp := new(dns.Msg)
							resp.SetRcode(r, dns.RcodeSuccess)
							assert.NoError(t, w.WriteMsg(resp))
						}).Once()
					}
					handler = mockHandler
				}

				chain.AddHandler(pattern, handler, h.priority, nil)
			}

			// Execute request
			r := new(dns.Msg)
			r.SetQuestion(tt.query, dns.TypeA)
			chain.ServeDNS(&mockResponseWriter{}, r)

			// Verify each handler was called exactly as expected
			for _, h := range tt.addHandlers {
				wasCalled := handlerCalls[h.pattern]
				assert.Equal(t, h.shouldMatch, wasCalled,
					"Handler for pattern %q was %s when it should%s have been",
					h.pattern,
					map[bool]string{true: "called", false: "not called"}[wasCalled],
					map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
			}

			// Verify total number of calls
			assert.Equal(t, tt.expectedCalls, len(handlerCalls),
				"Wrong number of total handler calls")
		})
	}
}
|
||||||
|
|
||||||
|
func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
scenario string
|
||||||
|
ops []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}
|
||||||
|
query string
|
||||||
|
expectedMatch string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "more specific domain matches first",
|
||||||
|
scenario: "sub.example.com should match before example.com",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "more specific domain matches first, both match subdomains",
|
||||||
|
scenario: "sub.example.com should match before example.com",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "maintain specificity order after removal",
|
||||||
|
scenario: "after removing most specific, should fall back to less specific",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
},
|
||||||
|
query: "test.sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "priority overrides specificity",
|
||||||
|
scenario: "less specific domain with higher priority should match first",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "equal priority respects specificity",
|
||||||
|
scenario: "with equal priority, more specific domain should match",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "specific matches before wildcard",
|
||||||
|
scenario: "specific domain should match before wildcard at same priority",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomain bool
|
||||||
|
}{
|
||||||
|
{"add", "*.example.com.", nbdns.PriorityDNSRoute, false},
|
||||||
|
{"add", "sub.example.com.", nbdns.PriorityDNSRoute, false},
|
||||||
|
},
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedMatch: "sub.example.com.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
handlers := make(map[string]*nbdns.MockSubdomainHandler)
|
||||||
|
|
||||||
|
for _, op := range tt.ops {
|
||||||
|
if op.action == "add" {
|
||||||
|
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
||||||
|
handlers[op.pattern] = handler
|
||||||
|
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||||
|
} else {
|
||||||
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
// Setup handler expectations
|
||||||
|
for pattern, handler := range handlers {
|
||||||
|
if pattern == tt.expectedMatch {
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||||
|
w := args.Get(0).(dns.ResponseWriter)
|
||||||
|
r := args.Get(1).(*dns.Msg)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetReply(r)
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
for pattern, handler := range handlers {
|
||||||
|
if pattern == tt.expectedMatch {
|
||||||
|
handler.AssertNumberOfCalls(t, "ServeDNS", 1)
|
||||||
|
} else {
|
||||||
|
handler.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -102,3 +102,17 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type noopHostConfigurator struct{}
|
||||||
|
|
||||||
|
func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n noopHostConfigurator) restoreHostDNS() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n noopHostConfigurator) supportCustomPort() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ const (
|
|||||||
arraySymbol = "* "
|
arraySymbol = "* "
|
||||||
digitSymbol = "# "
|
digitSymbol = "# "
|
||||||
scutilPath = "/usr/sbin/scutil"
|
scutilPath = "/usr/sbin/scutil"
|
||||||
|
dscacheutilPath = "/usr/bin/dscacheutil"
|
||||||
searchSuffix = "Search"
|
searchSuffix = "Search"
|
||||||
matchSuffix = "Match"
|
matchSuffix = "Match"
|
||||||
localSuffix = "Local"
|
localSuffix = "Local"
|
||||||
@@ -106,6 +107,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
return fmt.Errorf("add search domains: %w", err)
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,6 +128,10 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -316,6 +325,21 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
|||||||
return primaryService, router, nil
|
return primaryService, router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) flushDNSCache() error {
|
||||||
|
cmd := exec.Command(dscacheutilPath, "-flushcache")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("flush DNS cache: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = exec.Command("killall", "-HUP", "mDNSResponder")
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("flushed DNS cache")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
|
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
|
||||||
if err := s.restoreHostDNS(); err != nil {
|
if err := s.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||||
|
|||||||
@@ -48,11 +48,17 @@ type restoreHostManager interface {
|
|||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("System DNS manager discovered: %s", osManager)
|
log.Infof("System DNS manager discovered: %s", osManager)
|
||||||
return newHostManagerFromType(wgInterface, osManager)
|
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||||
|
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create host manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
|
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
|
||||||
|
|||||||
@@ -17,12 +17,24 @@ type localResolver struct {
|
|||||||
records sync.Map
|
records sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *localResolver) MatchSubdomains() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (d *localResolver) stop() {
|
func (d *localResolver) stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the local resolver
|
||||||
|
func (d *localResolver) String() string {
|
||||||
|
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||||
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
log.Tracef("received question: %#v", r.Question[0])
|
if len(r.Question) > 0 {
|
||||||
|
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
}
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
replyMessage.SetReply(r)
|
replyMessage.SetReply(r)
|
||||||
replyMessage.RecursionAvailable = true
|
replyMessage.RecursionAvailable = true
|
||||||
|
|||||||
@@ -3,14 +3,30 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockServer is the mock instance of a dns server
|
// MockServer is the mock instance of a dns server
|
||||||
type MockServer struct {
|
type MockServer struct {
|
||||||
InitializeFunc func() error
|
InitializeFunc func() error
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
|
RegisterHandlerFunc func([]string, dns.Handler, int)
|
||||||
|
DeregisterHandlerFunc func([]string, int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
if m.RegisterHandlerFunc != nil {
|
||||||
|
m.RegisterHandlerFunc(domains, handler, priority)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
||||||
|
if m.DeregisterHandlerFunc != nil {
|
||||||
|
m.DeregisterHandlerFunc(domains, priority)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize mock implementation of Initialize from Server interface
|
// Initialize mock implementation of Initialize from Server interface
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -15,23 +16,64 @@ import (
|
|||||||
|
|
||||||
const resolvconfCommand = "resolvconf"
|
const resolvconfCommand = "resolvconf"
|
||||||
|
|
||||||
|
// resolvconfType represents the type of resolvconf implementation
|
||||||
|
type resolvconfType int
|
||||||
|
|
||||||
|
func (r resolvconfType) String() string {
|
||||||
|
switch r {
|
||||||
|
case typeOpenresolv:
|
||||||
|
return "openresolv"
|
||||||
|
case typeResolvconf:
|
||||||
|
return "resolvconf"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
typeOpenresolv resolvconfType = iota
|
||||||
|
typeResolvconf
|
||||||
|
)
|
||||||
|
|
||||||
type resolvconf struct {
|
type resolvconf struct {
|
||||||
ifaceName string
|
ifaceName string
|
||||||
|
implType resolvconfType
|
||||||
|
|
||||||
originalSearchDomains []string
|
originalSearchDomains []string
|
||||||
originalNameServers []string
|
originalNameServers []string
|
||||||
othersConfigs []string
|
othersConfigs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// supported "openresolv" only
|
func detectResolvconfType() (resolvconfType, error) {
|
||||||
|
cmd := exec.Command(resolvconfCommand, "--version")
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return typeOpenresolv, fmt.Errorf("failed to determine resolvconf type: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(string(out), "openresolv") {
|
||||||
|
return typeOpenresolv, nil
|
||||||
|
}
|
||||||
|
return typeResolvconf, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
|
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
|
||||||
resolvConfEntries, err := parseDefaultResolvConf()
|
resolvConfEntries, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
implType, err := detectResolvconfType()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to detect resolvconf type, defaulting to openresolv: %v", err)
|
||||||
|
implType = typeOpenresolv
|
||||||
|
} else {
|
||||||
|
log.Infof("detected resolvconf type: %v", implType)
|
||||||
|
}
|
||||||
|
|
||||||
return &resolvconf{
|
return &resolvconf{
|
||||||
ifaceName: wgInterface,
|
ifaceName: wgInterface,
|
||||||
|
implType: implType,
|
||||||
originalSearchDomains: resolvConfEntries.searchDomains,
|
originalSearchDomains: resolvConfEntries.searchDomains,
|
||||||
originalNameServers: resolvConfEntries.nameServers,
|
originalNameServers: resolvConfEntries.nameServers,
|
||||||
othersConfigs: resolvConfEntries.others,
|
othersConfigs: resolvConfEntries.others,
|
||||||
@@ -80,8 +122,15 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) restoreHostDNS() error {
|
func (r *resolvconf) restoreHostDNS() error {
|
||||||
// openresolv only, debian resolvconf doesn't support "-f"
|
var cmd *exec.Cmd
|
||||||
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
|
||||||
|
switch r.implType {
|
||||||
|
case typeOpenresolv:
|
||||||
|
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||||
|
case typeResolvconf:
|
||||||
|
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
|
||||||
|
}
|
||||||
|
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||||
@@ -91,10 +140,21 @@ func (r *resolvconf) restoreHostDNS() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||||
// openresolv only, debian resolvconf doesn't support "-x"
|
var cmd *exec.Cmd
|
||||||
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
|
||||||
|
switch r.implType {
|
||||||
|
case typeOpenresolv:
|
||||||
|
// OpenResolv supports exclusive mode with -x
|
||||||
|
cmd = exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||||
|
case typeResolvconf:
|
||||||
|
cmd = exec.Command(resolvconfCommand, "-a", r.ifaceName)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported resolvconf type: %v", r.implType)
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Stdin = &content
|
cmd.Stdin = &content
|
||||||
_, err := cmd.Output()
|
out, err := cmd.Output()
|
||||||
|
log.Tracef("resolvconf output: %s", out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -30,6 +31,8 @@ type IosDnsManager interface {
|
|||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
|
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
||||||
|
DeregisterHandler(domains []string, priority int)
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
@@ -45,15 +48,18 @@ type registeredHandlerMap map[string]handlerWithStop
|
|||||||
type DefaultServer struct {
|
type DefaultServer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
|
disableSys bool
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
|
handlerPriorities map[string]int
|
||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
|
handlerChain *HandlerChain
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
@@ -74,12 +80,20 @@ type handlerWithStop interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type muxUpdate struct {
|
type muxUpdate struct {
|
||||||
domain string
|
domain string
|
||||||
handler handlerWithStop
|
handler handlerWithStop
|
||||||
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
|
func NewDefaultServer(
|
||||||
|
ctx context.Context,
|
||||||
|
wgInterface WGIface,
|
||||||
|
customAddress string,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
stateManager *statemanager.Manager,
|
||||||
|
disableSys bool,
|
||||||
|
) (*DefaultServer, error) {
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if customAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
@@ -96,7 +110,7 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
|
|||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
|
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||||
@@ -107,9 +121,10 @@ func NewDefaultServerPermanentUpstream(
|
|||||||
config nbdns.Config,
|
config nbdns.Config,
|
||||||
listener listener.NetworkChangeListener,
|
listener listener.NetworkChangeListener,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
disableSys bool,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||||
ds.hostsDNSHolder.set(hostsDnsList)
|
ds.hostsDNSHolder.set(hostsDnsList)
|
||||||
ds.permanent = true
|
ds.permanent = true
|
||||||
ds.addHostRootZone()
|
ds.addHostRootZone()
|
||||||
@@ -126,19 +141,30 @@ func NewDefaultServerIos(
|
|||||||
wgInterface WGIface,
|
wgInterface WGIface,
|
||||||
iosDnsManager IosDnsManager,
|
iosDnsManager IosDnsManager,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
disableSys bool,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||||
ds.iosDnsManager = iosDnsManager
|
ds.iosDnsManager = iosDnsManager
|
||||||
return ds
|
return ds
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
func newDefaultServer(
|
||||||
|
ctx context.Context,
|
||||||
|
wgInterface WGIface,
|
||||||
|
dnsService service,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
stateManager *statemanager.Manager,
|
||||||
|
disableSys bool,
|
||||||
|
) *DefaultServer {
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
service: dnsService,
|
disableSys: disableSys,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
service: dnsService,
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
handlerPriorities: make(map[string]int),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@@ -151,6 +177,51 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
s.registerHandler(domains, handler, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
log.Debugf("registering handler %s with priority %d", handler, priority)
|
||||||
|
|
||||||
|
for _, domain := range domains {
|
||||||
|
if domain == "" {
|
||||||
|
log.Warn("skipping empty domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
||||||
|
s.handlerPriorities[domain] = priority
|
||||||
|
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
s.deregisterHandler(domains, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
|
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||||
|
|
||||||
|
for _, domain := range domains {
|
||||||
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
|
// Only deregister from service if no handlers remain
|
||||||
|
if !s.handlerChain.HasHandlers(domain) {
|
||||||
|
if domain == "" {
|
||||||
|
log.Warn("skipping empty domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize instantiate host manager and the dns service
|
// Initialize instantiate host manager and the dns service
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
func (s *DefaultServer) Initialize() (err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
@@ -168,6 +239,16 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.stateManager.RegisterState(&ShutdownState{})
|
s.stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
|
// use noop host manager if requested or running in netstack mode.
|
||||||
|
// Netstack mode currently doesn't have a way to receive DNS requests.
|
||||||
|
// TODO: Use listener on localhost in netstack mode when running as root.
|
||||||
|
if s.disableSys || netstack.IsEnabled() {
|
||||||
|
log.Info("system DNS is disabled, not setting up host manager")
|
||||||
|
s.hostManager = &noopHostConfigurator{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
s.hostManager, err = s.initialize()
|
s.hostManager, err = s.initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initialize: %w", err)
|
return fmt.Errorf("initialize: %w", err)
|
||||||
@@ -216,47 +297,47 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
|||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
// UpdateDNSServer processes an update received from the management service
|
||||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
select {
|
if s.ctx.Err() != nil {
|
||||||
case <-s.ctx.Done():
|
|
||||||
log.Infof("not updating DNS server as context is closed")
|
log.Infof("not updating DNS server as context is closed")
|
||||||
return s.ctx.Err()
|
return s.ctx.Err()
|
||||||
default:
|
}
|
||||||
if serial < s.updateSerial {
|
|
||||||
return fmt.Errorf("not applying dns update, error: "+
|
|
||||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
|
||||||
}
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
if s.hostManager == nil {
|
if serial < s.updateSerial {
|
||||||
return fmt.Errorf("dns service is not initialized yet")
|
return fmt.Errorf("not applying dns update, error: "+
|
||||||
}
|
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||||
|
}
|
||||||
|
|
||||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
s.mux.Lock()
|
||||||
ZeroNil: true,
|
defer s.mux.Unlock()
|
||||||
IgnoreZeroValue: true,
|
|
||||||
SlicesAsSets: true,
|
|
||||||
UseStringer: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.previousConfigHash == hash {
|
if s.hostManager == nil {
|
||||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
return fmt.Errorf("dns service is not initialized yet")
|
||||||
s.updateSerial = serial
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.applyConfiguration(update); err != nil {
|
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||||
return fmt.Errorf("apply configuration: %w", err)
|
ZeroNil: true,
|
||||||
}
|
IgnoreZeroValue: true,
|
||||||
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.previousConfigHash == hash {
|
||||||
|
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||||
s.updateSerial = serial
|
s.updateSerial = serial
|
||||||
s.previousConfigHash = hash
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.applyConfiguration(update); err != nil {
|
||||||
|
return fmt.Errorf("apply configuration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.updateSerial = serial
|
||||||
|
s.previousConfigHash = hash
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) SearchDomains() []string {
|
func (s *DefaultServer) SearchDomains() []string {
|
||||||
@@ -343,14 +424,14 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
|
|
||||||
if len(customZone.Records) == 0 {
|
if len(customZone.Records) == 0 {
|
||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: customZone.Domain,
|
domain: customZone.Domain,
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
@@ -412,8 +493,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
priority: PriorityDefault,
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -429,8 +511,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
}
|
}
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: domain,
|
domain: domain,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -440,12 +523,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
|
handlersByPriority := make(map[string]int)
|
||||||
|
|
||||||
var isContainRootUpdate bool
|
var isContainRootUpdate bool
|
||||||
|
|
||||||
|
// First register new handlers
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.service.RegisterMux(update.domain, update.handler)
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.domain] = update.handler
|
muxUpdateMap[update.domain] = update.handler
|
||||||
|
handlersByPriority[update.domain] = update.priority
|
||||||
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
}
|
}
|
||||||
@@ -455,6 +542,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Then deregister old handlers not in the update
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
for key, existingHandler := range s.dnsMuxMap {
|
||||||
_, found := muxUpdateMap[key]
|
_, found := muxUpdateMap[key]
|
||||||
if !found {
|
if !found {
|
||||||
@@ -463,12 +551,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
} else {
|
} else {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
s.service.DeregisterMux(key)
|
// Deregister with the priority that was used to register
|
||||||
|
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
||||||
|
s.deregisterHandler([]string{key}, oldPriority)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
|
s.handlerPriorities = handlersByPriority
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
@@ -517,13 +609,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
s.service.DeregisterMux(nbdns.RootZone)
|
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
if _, found := removeIndex[item.Domain]; found {
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
s.currentConfig.Domains[i].Disabled = true
|
||||||
s.service.DeregisterMux(item.Domain)
|
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -554,7 +646,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
s.currentConfig.Domains[i].Disabled = false
|
||||||
s.service.RegisterMux(domain, handler)
|
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@@ -562,10 +654,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
if s.hostManager != nil {
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||||
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.updateNSState(nsGroup, nil, true)
|
s.updateNSState(nsGroup, nil, true)
|
||||||
@@ -593,7 +688,8 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
}
|
}
|
||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
|
||||||
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
@@ -292,7 +294,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -401,7 +403,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@@ -496,7 +498,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
@@ -512,7 +514,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
|
||||||
|
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
@@ -560,7 +562,9 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
hostManager: hostManager,
|
handlerChain: NewHandlerChain(),
|
||||||
|
handlerPriorities: make(map[string]int),
|
||||||
|
hostManager: hostManager,
|
||||||
currentConfig: HostDNSConfig{
|
currentConfig: HostDNSConfig{
|
||||||
Domains: []DomainConfig{
|
Domains: []DomainConfig{
|
||||||
{false, "domain0", false},
|
{false, "domain0", false},
|
||||||
@@ -629,7 +633,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|||||||
|
|
||||||
var dnsList []string
|
var dnsList []string
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@@ -653,7 +657,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@@ -745,7 +749,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@@ -872,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockHandler implements dns.Handler interface for testing
|
||||||
|
type MockHandler struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
m.Called(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockSubdomainHandler struct {
|
||||||
|
MockHandler
|
||||||
|
Subdomains bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSubdomainHandler) MatchSubdomains() bool {
|
||||||
|
return m.Subdomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||||
|
chain := NewHandlerChain()
|
||||||
|
|
||||||
|
dnsRouteHandler := &MockHandler{}
|
||||||
|
upstreamHandler := &MockSubdomainHandler{
|
||||||
|
Subdomains: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
||||||
|
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
expectedHandler dns.Handler
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain with dns route handler",
|
||||||
|
query: "example.com.",
|
||||||
|
expectedHandler: dnsRouteHandler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain should use upstream handler",
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedHandler: upstreamHandler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain should use upstream handler",
|
||||||
|
query: "deep.sub.example.com.",
|
||||||
|
expectedHandler: upstreamHandler,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tc.query, dns.TypeA)
|
||||||
|
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.AssertExpectations(t)
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset mocks
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.ExpectedCalls = nil
|
||||||
|
mh.Calls = nil
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.ExpectedCalls = nil
|
||||||
|
mh.Calls = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
log.Debugf("registering dns handler for pattern: %s", pattern)
|
||||||
s.dnsMux.Handle(pattern, handler)
|
s.dnsMux.Handle(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -66,6 +66,15 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the upstream resolver
|
||||||
|
func (u *upstreamResolverBase) String() string {
|
||||||
|
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) stop() {
|
func (u *upstreamResolverBase) stop() {
|
||||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||||
u.cancel()
|
u.cancel()
|
||||||
|
|||||||
157
client/internal/dnsfwd/forwarder.go
Normal file
157
client/internal/dnsfwd/forwarder.go
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
|
|
||||||
|
type DNSForwarder struct {
|
||||||
|
listenAddress string
|
||||||
|
ttl uint32
|
||||||
|
domains []string
|
||||||
|
|
||||||
|
dnsServer *dns.Server
|
||||||
|
mux *dns.ServeMux
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
|
||||||
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||||
|
return &DNSForwarder{
|
||||||
|
listenAddress: listenAddress,
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) Listen(domains []string) error {
|
||||||
|
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
dnsServer := &dns.Server{
|
||||||
|
Addr: f.listenAddress,
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
f.dnsServer = dnsServer
|
||||||
|
f.mux = mux
|
||||||
|
|
||||||
|
f.UpdateDomains(domains)
|
||||||
|
|
||||||
|
return dnsServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||||
|
log.Debugf("Updating domains from %v to %v", f.domains, domains)
|
||||||
|
|
||||||
|
for _, d := range f.domains {
|
||||||
|
f.mux.HandleRemove(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
newDomains := filterDomains(domains)
|
||||||
|
for _, d := range newDomains {
|
||||||
|
f.mux.HandleFunc(d, f.handleDNSQuery)
|
||||||
|
}
|
||||||
|
f.domains = newDomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
|
if f.dnsServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return f.dnsServer.ShutdownContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
if len(query.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||||
|
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
|
||||||
|
|
||||||
|
question := query.Question[0]
|
||||||
|
domain := question.Name
|
||||||
|
|
||||||
|
resp := query.SetReply(query)
|
||||||
|
|
||||||
|
ips, err := net.LookupIP(domain)
|
||||||
|
if err != nil {
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case errors.As(err, &dnsErr):
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
if dnsErr.IsNotFound {
|
||||||
|
// Pass through NXDOMAIN
|
||||||
|
resp.Rcode = dns.RcodeNameError
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsErr.Server != "" {
|
||||||
|
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||||
|
} else {
|
||||||
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
var respRecord dns.RR
|
||||||
|
if ip.To4() == nil {
|
||||||
|
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||||
|
rr := dns.AAAA{
|
||||||
|
AAAA: ip,
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: domain,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: f.ttl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respRecord = &rr
|
||||||
|
} else {
|
||||||
|
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||||
|
rr := dns.A{
|
||||||
|
A: ip,
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: domain,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: f.ttl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respRecord = &rr
|
||||||
|
}
|
||||||
|
resp.Answer = append(resp.Answer, respRecord)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterDomains returns a list of normalized domains
|
||||||
|
func filterDomains(domains []string) []string {
|
||||||
|
newDomains := make([]string, 0, len(domains))
|
||||||
|
for _, d := range domains {
|
||||||
|
if d == "" {
|
||||||
|
log.Warn("empty domain in DNS forwarder")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newDomains = append(newDomains, nbdns.NormalizeZone(d))
|
||||||
|
}
|
||||||
|
return newDomains
|
||||||
|
}
|
||||||
111
client/internal/dnsfwd/manager.go
Normal file
111
client/internal/dnsfwd/manager.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||||
|
ListenPort = 5353
|
||||||
|
dnsTTL = 60 //seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
firewall firewall.Manager
|
||||||
|
|
||||||
|
fwRules []firewall.Rule
|
||||||
|
dnsForwarder *DNSForwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(fw firewall.Manager) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
firewall: fw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(domains []string) error {
|
||||||
|
log.Infof("starting DNS forwarder")
|
||||||
|
if m.dnsForwarder != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.allowDNSFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
|
||||||
|
go func() {
|
||||||
|
if err := m.dnsForwarder.Listen(domains); err != nil {
|
||||||
|
// todo handle close error if it is exists
|
||||||
|
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) UpdateDomains(domains []string) {
|
||||||
|
if m.dnsForwarder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder.UpdateDomains(domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Stop(ctx context.Context) error {
|
||||||
|
if m.dnsForwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var mErr *multierror.Error
|
||||||
|
if err := m.dropDNSFirewall(); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.dnsForwarder.Close(ctx); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder = nil
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) allowDNSFirewall() error {
|
||||||
|
dport := &firewall.Port{
|
||||||
|
IsRange: false,
|
||||||
|
Values: []uint16{ListenPort},
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.firewall == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.fwRules = dnsRules
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) dropDNSFirewall() error {
|
||||||
|
var mErr *multierror.Error
|
||||||
|
for _, rule := range h.fwRules {
|
||||||
|
if err := h.firewall.DeletePeerRule(rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.fwRules = nil
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -17,23 +16,28 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/pion/ice/v3"
|
"github.com/pion/ice/v3"
|
||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
@@ -107,6 +111,13 @@ type EngineConfig struct {
|
|||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
|
|
||||||
|
DisableClientRoutes bool
|
||||||
|
DisableServerRoutes bool
|
||||||
|
DisableDNS bool
|
||||||
|
DisableFirewall bool
|
||||||
|
|
||||||
|
BlockLANAccess bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -117,7 +128,7 @@ type Engine struct {
|
|||||||
// mgmClient is a Management Service client
|
// mgmClient is a Management Service client
|
||||||
mgmClient mgm.Client
|
mgmClient mgm.Client
|
||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerConns map[string]*peer.Conn
|
peerStore *peerstore.Store
|
||||||
|
|
||||||
beforePeerHook nbnet.AddHookFunc
|
beforePeerHook nbnet.AddHookFunc
|
||||||
afterPeerHook nbnet.RemoveHookFunc
|
afterPeerHook nbnet.RemoveHookFunc
|
||||||
@@ -137,10 +148,6 @@ type Engine struct {
|
|||||||
TURNs []*stun.URI
|
TURNs []*stun.URI
|
||||||
stunTurn atomic.Value
|
stunTurn atomic.Value
|
||||||
|
|
||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
|
||||||
clientRoutes route.HAMap
|
|
||||||
clientRoutesMu sync.RWMutex
|
|
||||||
|
|
||||||
clientCtx context.Context
|
clientCtx context.Context
|
||||||
clientCancel context.CancelFunc
|
clientCancel context.CancelFunc
|
||||||
|
|
||||||
@@ -161,9 +168,10 @@ type Engine struct {
|
|||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
firewall manager.Manager
|
firewall manager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
acl acl.Manager
|
acl acl.Manager
|
||||||
|
dnsForwardMgr *dnsfwd.Manager
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
|
||||||
@@ -234,7 +242,7 @@ func NewEngineWithProbes(
|
|||||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
||||||
mgmClient: mgmClient,
|
mgmClient: mgmClient,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
peerConns: make(map[string]*peer.Conn),
|
peerStore: peerstore.NewConnStore(),
|
||||||
syncMsgMux: &sync.Mutex{},
|
syncMsgMux: &sync.Mutex{},
|
||||||
config: config,
|
config: config,
|
||||||
mobileDep: mobileDep,
|
mobileDep: mobileDep,
|
||||||
@@ -287,6 +295,13 @@ func (e *Engine) Stop() error {
|
|||||||
e.routeManager.Stop(e.stateManager)
|
e.routeManager.Stop(e.stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.dnsForwardMgr != nil {
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
if e.srWatcher != nil {
|
||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
@@ -300,10 +315,6 @@ func (e *Engine) Stop() error {
|
|||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.clientRoutesMu.Lock()
|
|
||||||
e.clientRoutes = nil
|
|
||||||
e.clientRoutesMu.Unlock()
|
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
@@ -373,16 +384,20 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(
|
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||||
e.ctx,
|
Context: e.ctx,
|
||||||
e.config.WgPrivateKey.PublicKey().String(),
|
PublicKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||||
e.config.DNSRouteInterval,
|
DNSRouteInterval: e.config.DNSRouteInterval,
|
||||||
e.wgInterface,
|
WGInterface: e.wgInterface,
|
||||||
e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
e.relayManager,
|
RelayManager: e.relayManager,
|
||||||
initialRoutes,
|
InitialRoutes: initialRoutes,
|
||||||
e.stateManager,
|
StateManager: e.stateManager,
|
||||||
)
|
DNSServer: dnsServer,
|
||||||
|
PeerStore: e.peerStore,
|
||||||
|
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||||
|
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||||
|
})
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to initialize route manager: %s", err)
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
@@ -400,17 +415,8 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("create wg interface: %w", err)
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
if err := e.createFirewall(); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
|
|
||||||
err = e.routeManager.EnableServerRouter(e.firewall)
|
|
||||||
if err != nil {
|
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("enable server router: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.udpMux, err = e.wgInterface.Up()
|
e.udpMux, err = e.wgInterface.Up()
|
||||||
@@ -452,6 +458,93 @@ func (e *Engine) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) createFirewall() error {
|
||||||
|
if e.config.DisableFirewall {
|
||||||
|
log.Infof("firewall is disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||||
|
if err != nil || e.firewall == nil {
|
||||||
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.initFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) initFirewall() error {
|
||||||
|
if e.firewall.IsServerRouteSupported() {
|
||||||
|
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||||
|
e.close()
|
||||||
|
return fmt.Errorf("enable server router: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.BlockLANAccess {
|
||||||
|
e.blockLanAccess()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.rpManager == nil || !e.config.RosenpassEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rosenpassPort := e.rpManager.GetAddress().Port
|
||||||
|
port := manager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||||
|
|
||||||
|
// this rule is static and will be torn down on engine down by the firewall manager
|
||||||
|
if _, err := e.firewall.AddPeerFiltering(
|
||||||
|
net.IP{0, 0, 0, 0},
|
||||||
|
manager.ProtocolUDP,
|
||||||
|
nil,
|
||||||
|
&port,
|
||||||
|
manager.ActionAccept,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
); err != nil {
|
||||||
|
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) blockLanAccess() {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// TODO: keep this updated
|
||||||
|
toBlock, err := getInterfacePrefixes()
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("get local addresses: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("blocking route LAN access for networks: %v", toBlock)
|
||||||
|
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
|
for _, network := range toBlock {
|
||||||
|
if _, err := e.firewall.AddRouteFiltering(
|
||||||
|
[]netip.Prefix{v4},
|
||||||
|
network,
|
||||||
|
manager.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
manager.ActionDrop,
|
||||||
|
); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if merr != nil {
|
||||||
|
log.Warnf("encountered errors blocking IPs to block LAN access: %v", nberrors.FormatErrorOrNil(merr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
||||||
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
||||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
@@ -460,8 +553,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
var modified []*mgmProto.RemotePeerConfig
|
var modified []*mgmProto.RemotePeerConfig
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
peerPubKey := p.GetWgPubKey()
|
peerPubKey := p.GetWgPubKey()
|
||||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
|
||||||
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
if allowedIPs != strings.Join(p.AllowedIps, ",") {
|
||||||
modified = append(modified, p)
|
modified = append(modified, p)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -492,17 +585,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
||||||
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
||||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
currentPeers := make([]string, 0, len(e.peerConns))
|
|
||||||
for p := range e.peerConns {
|
|
||||||
currentPeers = append(currentPeers, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
newPeers := make([]string, 0, len(peersUpdate))
|
newPeers := make([]string, 0, len(peersUpdate))
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
newPeers = append(newPeers, p.GetWgPubKey())
|
newPeers = append(newPeers, p.GetWgPubKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
toRemove := util.SliceDiff(currentPeers, newPeers)
|
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
||||||
|
|
||||||
for _, p := range toRemove {
|
for _, p := range toRemove {
|
||||||
err := e.removePeer(p)
|
err := e.removePeer(p)
|
||||||
@@ -516,7 +604,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
|
|
||||||
func (e *Engine) removeAllPeers() error {
|
func (e *Engine) removeAllPeers() error {
|
||||||
log.Debugf("removing all peer connections")
|
log.Debugf("removing all peer connections")
|
||||||
for p := range e.peerConns {
|
for _, p := range e.peerStore.PeersPubKey() {
|
||||||
err := e.removePeer(p)
|
err := e.removePeer(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -540,9 +628,8 @@ func (e *Engine) removePeer(peerKey string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, exists := e.peerConns[peerKey]
|
conn, exists := e.peerStore.Remove(peerKey)
|
||||||
if exists {
|
if exists {
|
||||||
delete(e.peerConns, peerKey)
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -629,6 +716,15 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx)
|
||||||
}
|
}
|
||||||
|
info.SetFlags(
|
||||||
|
e.config.RosenpassEnabled,
|
||||||
|
e.config.RosenpassPermissive,
|
||||||
|
&e.config.ServerSSHAllowed,
|
||||||
|
e.config.DisableClientRoutes,
|
||||||
|
e.config.DisableServerRoutes,
|
||||||
|
e.config.DisableDNS,
|
||||||
|
e.config.DisableFirewall,
|
||||||
|
)
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
log.Errorf("could not sync meta: error %s", err)
|
log.Errorf("could not sync meta: error %s", err)
|
||||||
@@ -649,18 +745,22 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
} else {
|
} else {
|
||||||
|
|
||||||
if sshConf.GetSshEnabled() {
|
if sshConf.GetSshEnabled() {
|
||||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
if runtime.GOOS == "windows" {
|
||||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// start SSH server if it wasn't running
|
// start SSH server if it wasn't running
|
||||||
if isNil(e.sshServer) {
|
if isNil(e.sshServer) {
|
||||||
|
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||||
|
}
|
||||||
// nil sshServer means it has not yet been started
|
// nil sshServer means it has not yet been started
|
||||||
var err error
|
var err error
|
||||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
|
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||||
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("create ssh server: %w", err)
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
@@ -709,16 +809,17 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
if conf.GetSshConfig() != nil {
|
if conf.GetSshConfig() != nil {
|
||||||
err := e.updateSSH(conf.GetSshConfig())
|
err := e.updateSSH(conf.GetSshConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed handling SSH server setup %v", err)
|
log.Warnf("failed handling SSH server setup: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
IP: e.config.WgAddr,
|
state.IP = e.config.WgAddr
|
||||||
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
||||||
FQDN: conf.GetFqdn(),
|
state.FQDN = conf.GetFqdn()
|
||||||
})
|
|
||||||
|
e.statusRecorder.UpdateLocalPeerState(state)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -732,6 +833,15 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx)
|
||||||
}
|
}
|
||||||
|
info.SetFlags(
|
||||||
|
e.config.RosenpassEnabled,
|
||||||
|
e.config.RosenpassPermissive,
|
||||||
|
&e.config.ServerSSHAllowed,
|
||||||
|
e.config.DisableClientRoutes,
|
||||||
|
e.config.DisableServerRoutes,
|
||||||
|
e.config.DisableDNS,
|
||||||
|
e.config.DisableFirewall,
|
||||||
|
)
|
||||||
|
|
||||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||||
@@ -786,7 +896,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||||
|
|
||||||
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
||||||
if networkMap.GetPeerConfig() != nil {
|
if networkMap.GetPeerConfig() != nil {
|
||||||
err := e.updateConfig(networkMap.GetPeerConfig())
|
err := e.updateConfig(networkMap.GetPeerConfig())
|
||||||
@@ -806,20 +915,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.acl.ApplyFiltering(networkMap)
|
e.acl.ApplyFiltering(networkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
protoRoutes := networkMap.GetRoutes()
|
// DNS forwarder
|
||||||
if protoRoutes == nil {
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
protoRoutes = []*mgmProto.Route{}
|
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
||||||
}
|
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains)
|
||||||
|
|
||||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
routes := toRoutes(networkMap.GetRoutes())
|
||||||
if err != nil {
|
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.clientRoutesMu.Lock()
|
|
||||||
e.clientRoutes = clientRoutes
|
|
||||||
e.clientRoutesMu.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
@@ -867,8 +972,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -881,7 +985,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
||||||
|
if networkMap.PeerConfig != nil {
|
||||||
|
return networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||||
|
if protoRoutes == nil {
|
||||||
|
protoRoutes = []*mgmProto.Route{}
|
||||||
|
}
|
||||||
|
|
||||||
routes := make([]*route.Route, 0)
|
routes := make([]*route.Route, 0)
|
||||||
for _, protoRoute := range protoRoutes {
|
for _, protoRoute := range protoRoutes {
|
||||||
var prefix netip.Prefix
|
var prefix netip.Prefix
|
||||||
@@ -892,6 +1007,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
convertedRoute := &route.Route{
|
convertedRoute := &route.Route{
|
||||||
ID: route.ID(protoRoute.ID),
|
ID: route.ID(protoRoute.ID),
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
@@ -908,6 +1024,23 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
|
||||||
|
if protoRoutes == nil {
|
||||||
|
protoRoutes = []*mgmProto.Route{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var dnsRoutes []string
|
||||||
|
for _, protoRoute := range protoRoutes {
|
||||||
|
if len(protoRoute.Domains) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if protoRoute.Peer == myPubKey {
|
||||||
|
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dnsRoutes
|
||||||
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||||
dnsUpdate := nbdns.Config{
|
dnsUpdate := nbdns.Config{
|
||||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||||
@@ -982,12 +1115,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||||
peerKey := peerConfig.GetWgPubKey()
|
peerKey := peerConfig.GetWgPubKey()
|
||||||
peerIPs := peerConfig.GetAllowedIps()
|
peerIPs := peerConfig.GetAllowedIps()
|
||||||
if _, ok := e.peerConns[peerKey]; !ok {
|
if _, ok := e.peerStore.PeerConn(peerKey); !ok {
|
||||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create peer connection: %w", err)
|
return fmt.Errorf("create peer connection: %w", err)
|
||||||
}
|
}
|
||||||
e.peerConns[peerKey] = conn
|
|
||||||
|
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
||||||
|
conn.Close()
|
||||||
|
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||||
@@ -1076,8 +1213,8 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
conn := e.peerConns[msg.Key]
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if conn == nil {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1135,7 +1272,7 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
||||||
case sProto.Body_MODE:
|
case sProto.Body_MODE:
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1235,6 +1372,16 @@ func (e *Engine) close() {
|
|||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||||
info := system.GetInfo(e.ctx)
|
info := system.GetInfo(e.ctx)
|
||||||
|
info.SetFlags(
|
||||||
|
e.config.RosenpassEnabled,
|
||||||
|
e.config.RosenpassPermissive,
|
||||||
|
&e.config.ServerSSHAllowed,
|
||||||
|
e.config.DisableClientRoutes,
|
||||||
|
e.config.DisableServerRoutes,
|
||||||
|
e.config.DisableDNS,
|
||||||
|
e.config.DisableFirewall,
|
||||||
|
)
|
||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -1293,6 +1440,7 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
if e.dnsServer != nil {
|
if e.dnsServer != nil {
|
||||||
return nil, e.dnsServer, nil
|
return nil, e.dnsServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "android":
|
case "android":
|
||||||
routes, dnsConfig, err := e.readInitialSettings()
|
routes, dnsConfig, err := e.readInitialSettings()
|
||||||
@@ -1306,14 +1454,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
*dnsConfig,
|
*dnsConfig,
|
||||||
e.mobileDep.NetworkChangeListener,
|
e.mobileDep.NetworkChangeListener,
|
||||||
e.statusRecorder,
|
e.statusRecorder,
|
||||||
|
e.config.DisableDNS,
|
||||||
)
|
)
|
||||||
go e.mobileDep.DnsReadyListener.OnReady()
|
go e.mobileDep.DnsReadyListener.OnReady()
|
||||||
return routes, dnsServer, nil
|
return routes, dnsServer, nil
|
||||||
|
|
||||||
case "ios":
|
case "ios":
|
||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||||
return nil, dnsServer, nil
|
return nil, dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -1322,26 +1473,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the current routes from the route map
|
|
||||||
func (e *Engine) GetClientRoutes() route.HAMap {
|
|
||||||
e.clientRoutesMu.RLock()
|
|
||||||
defer e.clientRoutesMu.RUnlock()
|
|
||||||
|
|
||||||
return maps.Clone(e.clientRoutes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
|
||||||
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
|
||||||
e.clientRoutesMu.RLock()
|
|
||||||
defer e.clientRoutesMu.RUnlock()
|
|
||||||
|
|
||||||
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
|
||||||
for id, v := range e.clientRoutes {
|
|
||||||
routes[id.NetID()] = v
|
|
||||||
}
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRouteManager returns the route manager
|
// GetRouteManager returns the route manager
|
||||||
func (e *Engine) GetRouteManager() routemanager.Manager {
|
func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||||
return e.routeManager
|
return e.routeManager
|
||||||
@@ -1426,9 +1557,8 @@ func (e *Engine) receiveProbeEvents() {
|
|||||||
go e.probes.WgProbe.Receive(e.ctx, func() bool {
|
go e.probes.WgProbe.Receive(e.ctx, func() bool {
|
||||||
log.Debug("received wg probe request")
|
log.Debug("received wg probe request")
|
||||||
|
|
||||||
for _, peer := range e.peerConns {
|
for _, key := range e.peerStore.PeersPubKey() {
|
||||||
key := peer.GetKey()
|
wgStats, err := e.wgInterface.GetStats(key)
|
||||||
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
||||||
}
|
}
|
||||||
@@ -1505,7 +1635,7 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
|
|
||||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||||
var vpnRoutes []netip.Prefix
|
var vpnRoutes []netip.Prefix
|
||||||
for _, routes := range e.GetClientRoutes() {
|
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||||
if len(routes) > 0 && routes[0] != nil {
|
if len(routes) > 0 && routes[0] != nil {
|
||||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||||
}
|
}
|
||||||
@@ -1563,7 +1693,7 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a deep copy to avoid external modifications
|
log.Debugf("Retrieving latest network map with size %d bytes", proto.Size(e.latestNetworkMap))
|
||||||
nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap)
|
nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
||||||
@@ -1573,6 +1703,40 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
|||||||
return nm, nil
|
return nm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||||
|
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||||
|
if !enabled {
|
||||||
|
if e.dnsForwardMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(domains) > 0 {
|
||||||
|
log.Infof("enable domain router service for domains: %v", domains)
|
||||||
|
if e.dnsForwardMgr == nil {
|
||||||
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
|
||||||
|
|
||||||
|
if err := e.dnsForwardMgr.Start(domains); err != nil {
|
||||||
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Infof("update domain router service for domains: %v", domains)
|
||||||
|
e.dnsForwardMgr.UpdateDomains(domains)
|
||||||
|
}
|
||||||
|
} else if e.dnsForwardMgr != nil {
|
||||||
|
log.Infof("disable domain router service")
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
for _, check := range checks {
|
for _, check := range checks {
|
||||||
@@ -1590,3 +1754,45 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
|||||||
return slices.Equal(checks.Files, oChecks.Files)
|
return slices.Equal(checks.Files, oChecks.Files)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getInterfacePrefixes() ([]netip.Prefix, error) {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get interfaces: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var prefixes []netip.Prefix
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("get addresses for interface %s: %w", iface.Name, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipNet, ok := addr.(*net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("cast address to IPNet: %v", addr))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addr, ok := netip.AddrFromSlice(ipNet.IP)
|
||||||
|
if !ok {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("cast IPNet to netip.Addr: %v", ipNet.IP))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ones, _ := ipNet.Mask.Size()
|
||||||
|
prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked()
|
||||||
|
ip := prefix.Addr()
|
||||||
|
|
||||||
|
// TODO: add IPv6
|
||||||
|
if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixes = append(prefixes, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefixes, nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ import (
|
|||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -69,8 +71,7 @@ func TestMain(m *testing.M) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_SSH(t *testing.T) {
|
func TestEngine_SSH(t *testing.T) {
|
||||||
// todo resolve test execution on freebsd
|
if runtime.GOOS == "windows" {
|
||||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
|
||||||
t.Skip("skipping TestEngine_SSH")
|
t.Skip("skipping TestEngine_SSH")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +252,14 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
engine.wgInterface = wgIface
|
engine.wgInterface = wgIface
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
|
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||||
|
Context: ctx,
|
||||||
|
PublicKey: key.PublicKey().String(),
|
||||||
|
DNSRouteInterval: time.Minute,
|
||||||
|
WGInterface: engine.wgInterface,
|
||||||
|
StatusRecorder: engine.statusRecorder,
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
})
|
||||||
_, _, err = engine.routeManager.Init()
|
_, _, err = engine.routeManager.Init()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -391,8 +399,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(engine.peerConns) != c.expectedLen {
|
if len(engine.peerStore.PeersPubKey()) != c.expectedLen {
|
||||||
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns))
|
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if engine.networkSerial != c.expectedSerial {
|
if engine.networkSerial != c.expectedSerial {
|
||||||
@@ -400,7 +408,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range c.expectedPeers {
|
for _, p := range c.expectedPeers {
|
||||||
conn, ok := engine.peerConns[p.GetWgPubKey()]
|
conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey())
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
||||||
}
|
}
|
||||||
@@ -625,10 +633,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
}{}
|
}{}
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
input.inputSerial = updateSerial
|
input.inputSerial = updateSerial
|
||||||
input.inputRoutes = newRoutes
|
input.inputRoutes = newRoutes
|
||||||
return nil, nil, testCase.inputErr
|
return testCase.inputErr
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -801,8 +809,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1196,7 +1204,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -1218,7 +1226,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -1237,7 +1245,8 @@ func getConnectedPeers(e *Engine) int {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
i := 0
|
i := 0
|
||||||
for _, conn := range e.peerConns {
|
for _, id := range e.peerStore.PeersPubKey() {
|
||||||
|
conn, _ := e.peerStore.PeerConn(id)
|
||||||
if conn.Status() == peer.StatusConnected {
|
if conn.Status() == peer.StatusConnected {
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
@@ -1249,5 +1258,5 @@ func getPeers(e *Engine) int {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
return len(e.peerConns)
|
return len(e.peerStore.PeersPubKey())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// IsLoginRequired check that the server is support SSO or not
|
// IsLoginRequired check that the server is support SSO or not
|
||||||
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
|
func IsLoginRequired(ctx context.Context, config *Config) (bool, error) {
|
||||||
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
|
mgmURL := config.ManagementURL
|
||||||
|
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -33,12 +34,12 @@ func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, ss
|
|||||||
}()
|
}()
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(sshKey))
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
|
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||||
if isLoginNeeded(err) {
|
if isLoginNeeded(err) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@@ -67,10 +68,10 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
if serverKey != nil && isRegistrationNeeded(err) {
|
||||||
log.Debugf("peer registration required")
|
log.Debugf("peer registration required")
|
||||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,7 +100,7 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
|
|||||||
return mgmClient, err
|
return mgmClient, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
|
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, error) {
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
serverKey, err := mgmClient.GetServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
@@ -107,13 +108,22 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
|||||||
}
|
}
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
|
sysInfo.SetFlags(
|
||||||
|
config.RosenpassEnabled,
|
||||||
|
config.RosenpassPermissive,
|
||||||
|
config.ServerSSHAllowed,
|
||||||
|
config.DisableClientRoutes,
|
||||||
|
config.DisableServerRoutes,
|
||||||
|
config.DisableDNS,
|
||||||
|
config.DisableFirewall,
|
||||||
|
)
|
||||||
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
||||||
return serverKey, err
|
return serverKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
if err != nil && jwtToken == "" {
|
if err != nil && jwtToken == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||||
@@ -121,6 +131,15 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
|||||||
|
|
||||||
log.Debugf("sending peer registration request to Management Service")
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx)
|
||||||
|
info.SetFlags(
|
||||||
|
config.RosenpassEnabled,
|
||||||
|
config.RosenpassPermissive,
|
||||||
|
config.ServerSSHAllowed,
|
||||||
|
config.DisableClientRoutes,
|
||||||
|
config.DisableServerRoutes,
|
||||||
|
config.DisableDNS,
|
||||||
|
config.DisableFirewall,
|
||||||
|
)
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
|
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
|
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
|
||||||
|
|||||||
@@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
|||||||
conn.wgProxyRelay = proxy
|
conn.wgProxyRelay = proxy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowedIP returns the allowed IP of the remote peer
|
||||||
|
func (conn *Conn) AllowedIP() net.IP {
|
||||||
|
return conn.allowedIP
|
||||||
|
}
|
||||||
|
|
||||||
func isController(config ConnConfig) bool {
|
func isController(config ConnConfig) bool {
|
||||||
return config.LocalKey > config.Key
|
return config.LocalKey > config.Key
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ResolvedDomainInfo struct {
|
||||||
|
Prefixes []netip.Prefix
|
||||||
|
ParentDomain domain.Domain
|
||||||
|
}
|
||||||
|
|
||||||
// State contains the latest state of a peer
|
// State contains the latest state of a peer
|
||||||
type State struct {
|
type State struct {
|
||||||
Mux *sync.RWMutex
|
Mux *sync.RWMutex
|
||||||
@@ -79,6 +84,12 @@ type LocalPeerState struct {
|
|||||||
Routes map[string]struct{}
|
Routes map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone returns a copy of the LocalPeerState
|
||||||
|
func (l LocalPeerState) Clone() LocalPeerState {
|
||||||
|
l.Routes = maps.Clone(l.Routes)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
// SignalState contains the latest state of a signal connection
|
// SignalState contains the latest state of a signal connection
|
||||||
type SignalState struct {
|
type SignalState struct {
|
||||||
URL string
|
URL string
|
||||||
@@ -138,7 +149,7 @@ type Status struct {
|
|||||||
rosenpassEnabled bool
|
rosenpassEnabled bool
|
||||||
rosenpassPermissive bool
|
rosenpassPermissive bool
|
||||||
nsGroupStates []NSGroupState
|
nsGroupStates []NSGroupState
|
||||||
resolvedDomainsStates map[domain.Domain][]netip.Prefix
|
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||||
|
|
||||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||||
@@ -156,7 +167,7 @@ func NewRecorder(mgmAddress string) *Status {
|
|||||||
offlinePeers: make([]State, 0),
|
offlinePeers: make([]State, 0),
|
||||||
notifier: newNotifier(),
|
notifier: newNotifier(),
|
||||||
mgmAddress: mgmAddress,
|
mgmAddress: mgmAddress,
|
||||||
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
|
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -496,7 +507,7 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
|||||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
return d.localPeer
|
return d.localPeer.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalPeerState updates local peer status
|
// UpdateLocalPeerState updates local peer status
|
||||||
@@ -591,16 +602,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
|||||||
d.nsGroupStates = dnsStates
|
d.nsGroupStates = dnsStates
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
|
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
d.resolvedDomainsStates[domain] = prefixes
|
|
||||||
|
// Store both the original domain pattern and resolved domain
|
||||||
|
d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{
|
||||||
|
Prefixes: prefixes,
|
||||||
|
ParentDomain: originalDomain,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
delete(d.resolvedDomainsStates, domain)
|
|
||||||
|
// Remove all entries that have this domain as their parent
|
||||||
|
for k, v := range d.resolvedDomainsStates {
|
||||||
|
if v.ParentDomain == domain {
|
||||||
|
delete(d.resolvedDomainsStates, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetRosenpassState() RosenpassState {
|
func (d *Status) GetRosenpassState() RosenpassState {
|
||||||
@@ -702,7 +724,7 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
|||||||
return d.nsGroupStates
|
return d.nsGroupStates
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
return maps.Clone(d.resolvedDomainsStates)
|
return maps.Clone(d.resolvedDomainsStates)
|
||||||
|
|||||||
@@ -255,6 +255,10 @@ func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
|
|||||||
defer w.muxAgent.Unlock()
|
defer w.muxAgent.Unlock()
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
|
if w.agent == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.agent.Close(); err != nil {
|
if err := w.agent.Close(); err != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
87
client/internal/peerstore/store.go
Normal file
87
client/internal/peerstore/store.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package peerstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store is a thread-safe store for peer connections.
|
||||||
|
type Store struct {
|
||||||
|
peerConns map[string]*peer.Conn
|
||||||
|
peerConnsMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnStore() *Store {
|
||||||
|
return &Store{
|
||||||
|
peerConns: make(map[string]*peer.Conn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool {
|
||||||
|
s.peerConnsMu.Lock()
|
||||||
|
defer s.peerConnsMu.Unlock()
|
||||||
|
|
||||||
|
_, ok := s.peerConns[pubKey]
|
||||||
|
if ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.peerConns[pubKey] = conn
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
|
||||||
|
s.peerConnsMu.Lock()
|
||||||
|
defer s.peerConnsMu.Unlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
delete(s.peerConns, pubKey)
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return p.WgConfig().AllowedIps, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return p.AllowedIP(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) PeersPubKey() []string {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
return maps.Keys(s.peerConns)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
runtime "runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
@@ -13,12 +14,20 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
handlerTypeDynamic = iota
|
||||||
|
handlerTypeDomain
|
||||||
|
handlerTypeStatic
|
||||||
|
)
|
||||||
|
|
||||||
type routerPeerStatus struct {
|
type routerPeerStatus struct {
|
||||||
connected bool
|
connected bool
|
||||||
relayed bool
|
relayed bool
|
||||||
@@ -53,7 +62,18 @@ type clientNetwork struct {
|
|||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
func newClientNetworkWatcher(
|
||||||
|
ctx context.Context,
|
||||||
|
dnsRouteInterval time.Duration,
|
||||||
|
wgInterface iface.IWGIface,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
|
useNewDNSRoute bool,
|
||||||
|
) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
@@ -65,7 +85,17 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
|||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan routesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
handler: handlerFromRoute(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
dnsRouteInterval,
|
||||||
|
statusRecorder,
|
||||||
|
wgInterface,
|
||||||
|
dnsServer,
|
||||||
|
peerStore,
|
||||||
|
useNewDNSRoute,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
@@ -368,10 +398,50 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
|
func handlerFromRoute(
|
||||||
if rt.IsDynamic() {
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
dnsRouterInteval time.Duration,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
wgInterface iface.IWGIface,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
|
useNewDNSRoute bool,
|
||||||
|
) RouteHandler {
|
||||||
|
switch handlerType(rt, useNewDNSRoute) {
|
||||||
|
case handlerTypeDomain:
|
||||||
|
return dnsinterceptor.New(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
statusRecorder,
|
||||||
|
dnsServer,
|
||||||
|
peerStore,
|
||||||
|
)
|
||||||
|
case handlerTypeDynamic:
|
||||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
return dynamic.NewRoute(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
dnsRouterInteval,
|
||||||
|
statusRecorder,
|
||||||
|
wgInterface,
|
||||||
|
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||||
}
|
}
|
||||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
}
|
||||||
|
|
||||||
|
func handlerType(rt *route.Route, useNewDNSRoute bool) int {
|
||||||
|
if !rt.IsDynamic() {
|
||||||
|
return handlerTypeStatic
|
||||||
|
}
|
||||||
|
|
||||||
|
if useNewDNSRoute && runtime.GOOS != "ios" {
|
||||||
|
return handlerTypeDomain
|
||||||
|
}
|
||||||
|
return handlerTypeDynamic
|
||||||
}
|
}
|
||||||
|
|||||||
356
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
356
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
@@ -0,0 +1,356 @@
|
|||||||
|
package dnsinterceptor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type domainMap map[domain.Domain][]netip.Prefix
|
||||||
|
|
||||||
|
type DnsInterceptor struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
route *route.Route
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter
|
||||||
|
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
dnsServer nbdns.Server
|
||||||
|
currentPeerKey string
|
||||||
|
interceptedDomains domainMap
|
||||||
|
peerStore *peerstore.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
|
) *DnsInterceptor {
|
||||||
|
return &DnsInterceptor{
|
||||||
|
route: rt,
|
||||||
|
routeRefCounter: routeRefCounter,
|
||||||
|
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
dnsServer: dnsServer,
|
||||||
|
interceptedDomains: make(domainMap),
|
||||||
|
peerStore: peerStore,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) String() string {
|
||||||
|
return d.route.Domains.SafeString()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||||
|
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) RemoveRoute() error {
|
||||||
|
d.mu.Lock()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
||||||
|
}
|
||||||
|
if d.currentPeerKey != "" {
|
||||||
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||||
|
|
||||||
|
}
|
||||||
|
for _, domain := range d.route.Domains {
|
||||||
|
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
clear(d.interceptedDomains)
|
||||||
|
d.mu.Unlock()
|
||||||
|
|
||||||
|
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||||
|
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||||
|
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||||
|
prefix.Addr(),
|
||||||
|
domain.SafeString(),
|
||||||
|
ref.Out,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.currentPeerKey = peerKey
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, prefixes := range d.interceptedDomains {
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.currentPeerKey = ""
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeDNS implements the dns.Handler interface
|
||||||
|
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||||
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
|
d.mu.RLock()
|
||||||
|
peerKey := d.currentPeerKey
|
||||||
|
d.mu.RUnlock()
|
||||||
|
|
||||||
|
if peerKey == "" {
|
||||||
|
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
|
||||||
|
|
||||||
|
d.continueToNextHandler(w, r, "no current peer key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get upstream IP: %v", err)
|
||||||
|
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &dns.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
Net: "udp",
|
||||||
|
}
|
||||||
|
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
||||||
|
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||||
|
|
||||||
|
var answer []dns.RR
|
||||||
|
if reply != nil {
|
||||||
|
answer = reply.Answer
|
||||||
|
}
|
||||||
|
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||||
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
|
log.Errorf("failed writing DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reply.Id = r.Id
|
||||||
|
if err := d.writeMsg(w, reply); err != nil {
|
||||||
|
log.Errorf("failed writing DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// continueToNextHandler signals the handler chain to try the next handler
|
||||||
|
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||||
|
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||||
|
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
// Set Zero bit to signal handler chain to continue
|
||||||
|
resp.MsgHdr.Zero = true
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed writing DNS continue response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
|
||||||
|
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
|
||||||
|
}
|
||||||
|
return peerAllowedIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeMsg forwards the DNS response r to the client via w. Before writing,
// it extracts all A/AAAA answers, converts each address into a host prefix
// (/32 or /128), and feeds the resulting set into updateDomainPrefixes so the
// route/allowed-IP state tracks what the domain currently resolves to.
// A nil message is rejected; a failed write is returned as an error.
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
	if r == nil {
		return fmt.Errorf("received nil DNS message")
	}

	// Only responses that carry both a question and at least one answer
	// contribute to route updates; everything else is passed through as-is.
	if len(r.Answer) > 0 && len(r.Question) > 0 {
		// Recover the handler pattern that originally matched this query,
		// if the writer is our chaining wrapper.
		origPattern := ""
		if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
			origPattern = writer.GetOrigPattern()
		}

		resolvedDomain := domain.Domain(r.Question[0].Name)

		// already punycode via RegisterHandler()
		originalDomain := domain.Domain(origPattern)
		if originalDomain == "" {
			// No pattern available: fall back to the resolved name itself.
			originalDomain = resolvedDomain
		}

		var newPrefixes []netip.Prefix
		for _, answer := range r.Answer {
			var ip netip.Addr
			switch rr := answer.(type) {
			case *dns.A:
				// NOTE(review): netip.AddrFromSlice keeps the slice's length;
				// if rr.A is ever a 16-byte IPv4-mapped form this yields a /128
				// rather than a /32 — assumed 4-byte here, TODO confirm.
				addr, ok := netip.AddrFromSlice(rr.A)
				if !ok {
					log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
					continue
				}
				ip = addr
			case *dns.AAAA:
				addr, ok := netip.AddrFromSlice(rr.AAAA)
				if !ok {
					log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
					continue
				}
				ip = addr
			default:
				// CNAME, TXT, etc. carry no routable address — skip.
				continue
			}

			// Single-host prefix: /32 for IPv4, /128 for IPv6.
			prefix := netip.PrefixFrom(ip, ip.BitLen())
			newPrefixes = append(newPrefixes, prefix)
		}

		if len(newPrefixes) > 0 {
			// Route bookkeeping failures are logged but do not block the
			// DNS response from reaching the client.
			if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
				log.Errorf("failed to update domain prefixes: %v", err)
			}
		}
	}

	if err := w.WriteMsg(r); err != nil {
		return fmt.Errorf("failed to write DNS response: %v", err)
	}

	return nil
}
|
||||||
|
|
||||||
|
// updateDomainPrefixes reconciles the tracked prefixes for resolvedDomain
// against newPrefixes: prefixes that appeared are added to the route and
// allowed-IP reference counters, prefixes that disappeared are removed
// (unless the route is configured to keep stale routes). originalDomain is
// the handler pattern used only for status reporting. All individual
// counter errors are collected into a multierror and returned at the end so
// a single failure does not abort the remaining updates. Caller-visible
// state (d.interceptedDomains, status recorder) is only touched when
// something actually changed. Holds d.mu for the whole operation.
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Diff the previously recorded prefixes against the fresh resolution.
	oldPrefixes := d.interceptedDomains[resolvedDomain]
	toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)

	var merr *multierror.Error

	// Add new prefixes
	for _, prefix := range toAdd {
		// Route must exist before the allowed-IP entry; if the route fails,
		// skip the allowed-IP step for this prefix entirely.
		if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
			continue
		}

		// Without a current peer there is nothing to attach the allowed IP to.
		if d.currentPeerKey == "" {
			continue
		}
		if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
		} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
			// Another peer already owns this IP in the counter — warn that
			// HA routing cannot apply for it.
			log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
				prefix.Addr(),
				resolvedDomain.SafeString(),
				ref.Out,
			)
		}
	}

	// When KeepRoute is set, stale prefixes are intentionally left in place.
	if !d.route.KeepRoute {
		// Remove old prefixes
		for _, prefix := range toRemove {
			if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
				merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
			}
			if d.currentPeerKey != "" {
				if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
					merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
				}
			}
		}
	}

	// Update domain prefixes using resolved domain as key
	if len(toAdd) > 0 || len(toRemove) > 0 {
		d.interceptedDomains[resolvedDomain] = newPrefixes
		// Status reporting uses the pattern without the trailing FQDN dot.
		originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
		d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)

		if len(toAdd) > 0 {
			log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
				resolvedDomain.SafeString(),
				originalDomain.SafeString(),
				toAdd)
		}
		if len(toRemove) > 0 {
			log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
				resolvedDomain.SafeString(),
				originalDomain.SafeString(),
				toRemove)
		}
	}

	return nberrors.FormatErrorOrNil(merr)
}
|
||||||
|
|
||||||
|
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||||
|
prefixSet := make(map[netip.Prefix]bool)
|
||||||
|
for _, prefix := range oldPrefixes {
|
||||||
|
prefixSet[prefix] = false
|
||||||
|
}
|
||||||
|
for _, prefix := range newPrefixes {
|
||||||
|
if _, exists := prefixSet[prefix]; exists {
|
||||||
|
prefixSet[prefix] = true
|
||||||
|
} else {
|
||||||
|
toAdd = append(toAdd, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for prefix, inUse := range prefixSet {
|
||||||
|
if !inUse {
|
||||||
|
toRemove = append(toRemove, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -74,11 +74,7 @@ func NewRoute(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) String() string {
|
func (r *Route) String() string {
|
||||||
s, err := r.route.Domains.String()
|
return r.route.Domains.SafeString()
|
||||||
if err != nil {
|
|
||||||
return r.route.Domains.PunycodeString()
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) AddRoute(ctx context.Context) error {
|
func (r *Route) AddRoute(ctx context.Context) error {
|
||||||
@@ -292,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
|
|||||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||||
r.dynamicDomains[domain] = updatedPrefixes
|
r.dynamicDomains[domain] = updatedPrefixes
|
||||||
|
|
||||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
|
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
|||||||
@@ -12,12 +12,16 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
@@ -33,15 +37,32 @@ import (
|
|||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error
|
||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
|
GetClientRoutes() route.HAMap
|
||||||
|
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
EnableServerRouter(firewall firewall.Manager) error
|
EnableServerRouter(firewall firewall.Manager) error
|
||||||
Stop(stateManager *statemanager.Manager)
|
Stop(stateManager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ManagerConfig struct {
|
||||||
|
Context context.Context
|
||||||
|
PublicKey string
|
||||||
|
DNSRouteInterval time.Duration
|
||||||
|
WGInterface iface.IWGIface
|
||||||
|
StatusRecorder *peer.Status
|
||||||
|
RelayManager *relayClient.Manager
|
||||||
|
InitialRoutes []*route.Route
|
||||||
|
StateManager *statemanager.Manager
|
||||||
|
DNSServer dns.Server
|
||||||
|
PeerStore *peerstore.Store
|
||||||
|
DisableClientRoutes bool
|
||||||
|
DisableServerRoutes bool
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultManager is the default instance of a route manager
|
// DefaultManager is the default instance of a route manager
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -49,7 +70,7 @@ type DefaultManager struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||||
routeSelector *routeselector.RouteSelector
|
routeSelector *routeselector.RouteSelector
|
||||||
serverRouter serverRouter
|
serverRouter *serverRouter
|
||||||
sysOps *systemops.SysOps
|
sysOps *systemops.SysOps
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
relayMgr *relayClient.Manager
|
relayMgr *relayClient.Manager
|
||||||
@@ -60,52 +81,80 @@ type DefaultManager struct {
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
|
clientRoutes route.HAMap
|
||||||
|
dnsServer dns.Server
|
||||||
|
peerStore *peerstore.Store
|
||||||
|
useNewDNSRoute bool
|
||||||
|
disableClientRoutes bool
|
||||||
|
disableServerRoutes bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(
|
func NewManager(config ManagerConfig) *DefaultManager {
|
||||||
ctx context.Context,
|
mCTX, cancel := context.WithCancel(config.Context)
|
||||||
pubKey string,
|
|
||||||
dnsRouteInterval time.Duration,
|
|
||||||
wgInterface iface.IWGIface,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
relayMgr *relayClient.Manager,
|
|
||||||
initialRoutes []*route.Route,
|
|
||||||
stateManager *statemanager.Manager,
|
|
||||||
) *DefaultManager {
|
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
|
||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
sysOps := systemops.NewSysOps(wgInterface, notifier)
|
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||||
|
|
||||||
dm := &DefaultManager{
|
dm := &DefaultManager{
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
dnsRouteInterval: dnsRouteInterval,
|
dnsRouteInterval: config.DNSRouteInterval,
|
||||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||||
relayMgr: relayMgr,
|
relayMgr: config.RelayManager,
|
||||||
sysOps: sysOps,
|
sysOps: sysOps,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: config.StatusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: config.WGInterface,
|
||||||
pubKey: pubKey,
|
pubKey: config.PublicKey,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
stateManager: stateManager,
|
stateManager: config.StateManager,
|
||||||
|
dnsServer: config.DNSServer,
|
||||||
|
peerStore: config.PeerStore,
|
||||||
|
disableClientRoutes: config.DisableClientRoutes,
|
||||||
|
disableServerRoutes: config.DisableServerRoutes,
|
||||||
}
|
}
|
||||||
|
|
||||||
dm.routeRefCounter = refcounter.New(
|
// don't proceed with client routes if it is disabled
|
||||||
|
if config.DisableClientRoutes {
|
||||||
|
return dm
|
||||||
|
}
|
||||||
|
|
||||||
|
dm.setupRefCounters()
|
||||||
|
|
||||||
|
if runtime.GOOS == "android" {
|
||||||
|
cr := dm.initialClientRoutes(config.InitialRoutes)
|
||||||
|
dm.notifier.SetInitialClientRoutes(cr)
|
||||||
|
}
|
||||||
|
return dm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) setupRefCounters() {
|
||||||
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ struct{}) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
dm.allowedIPsRefCounter = refcounter.New(
|
if netstack.IsEnabled() {
|
||||||
|
m.routeRefCounter = refcounter.New(
|
||||||
|
func(netip.Prefix, struct{}) (struct{}, error) {
|
||||||
|
return struct{}{}, refcounter.ErrIgnore
|
||||||
|
},
|
||||||
|
func(netip.Prefix, struct{}) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.allowedIPsRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, peerKey string) (string, error) {
|
func(prefix netip.Prefix, peerKey string) (string, error) {
|
||||||
// save peerKey to use it in the remove function
|
// save peerKey to use it in the remove function
|
||||||
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
|
return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix.String())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, peerKey string) error {
|
func(prefix netip.Prefix, peerKey string) error {
|
||||||
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
|
if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
|
||||||
if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
|
if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -114,19 +163,13 @@ func NewManager(
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
|
||||||
cr := dm.clientRoutes(initialRoutes)
|
|
||||||
dm.notifier.SetInitialClientRoutes(cr)
|
|
||||||
}
|
|
||||||
return dm
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init sets up the routing
|
// Init sets up the routing
|
||||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
m.routeSelector = m.initSelector()
|
m.routeSelector = m.initSelector()
|
||||||
|
|
||||||
if nbnet.CustomRoutingDisabled() {
|
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,6 +215,15 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||||
|
if m.disableServerRoutes {
|
||||||
|
log.Info("server routes are disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if firewall == nil {
|
||||||
|
return errors.New("firewall manager is not set")
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -198,7 +250,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !nbnet.CustomRoutingDisabled() {
|
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
||||||
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", err)
|
log.Errorf("Error cleaning up routing: %v", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -207,33 +259,43 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.ctx = nil
|
m.ctx = nil
|
||||||
|
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.clientRoutes = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("not updating routes as context is closed")
|
log.Infof("not updating routes as context is closed")
|
||||||
return nil, nil, m.ctx.Err()
|
return nil
|
||||||
default:
|
default:
|
||||||
m.mux.Lock()
|
}
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.useNewDNSRoute = useNewDNSRoute
|
||||||
|
|
||||||
|
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||||
|
|
||||||
|
if !m.disableClientRoutes {
|
||||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||||
|
|
||||||
if m.serverRouter != nil {
|
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("update routes: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return newServerRoutesMap, newClientRoutesIDMap, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.serverRouter != nil {
|
||||||
|
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.clientRoutes = newClientRoutesIDMap
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRouteChangeListener set RouteListener for route change Notifier
|
// SetRouteChangeListener set RouteListener for route change Notifier
|
||||||
@@ -251,9 +313,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
return m.routeSelector
|
return m.routeSelector
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the client routes
|
// GetClientRoutes returns most recent list of clientRoutes received from the Management Service
|
||||||
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
|
func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
||||||
return m.clientNetworks
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
return maps.Clone(m.clientRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||||
|
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes))
|
||||||
|
for id, v := range m.clientRoutes {
|
||||||
|
routes[id.NetID()] = v
|
||||||
|
}
|
||||||
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||||
@@ -273,7 +350,18 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
clientNetworkWatcher := newClientNetworkWatcher(
|
||||||
|
m.ctx,
|
||||||
|
m.dnsRouteInterval,
|
||||||
|
m.wgInterface,
|
||||||
|
m.statusRecorder,
|
||||||
|
routes[0],
|
||||||
|
m.routeRefCounter,
|
||||||
|
m.allowedIPsRefCounter,
|
||||||
|
m.dnsServer,
|
||||||
|
m.peerStore,
|
||||||
|
m.useNewDNSRoute,
|
||||||
|
)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||||
@@ -302,7 +390,18 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
|||||||
for id, routes := range networks {
|
for id, routes := range networks {
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
if !found {
|
if !found {
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
clientNetworkWatcher = newClientNetworkWatcher(
|
||||||
|
m.ctx,
|
||||||
|
m.dnsRouteInterval,
|
||||||
|
m.wgInterface,
|
||||||
|
m.statusRecorder,
|
||||||
|
routes[0],
|
||||||
|
m.routeRefCounter,
|
||||||
|
m.allowedIPsRefCounter,
|
||||||
|
m.dnsServer,
|
||||||
|
m.peerStore,
|
||||||
|
m.useNewDNSRoute,
|
||||||
|
)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
}
|
}
|
||||||
@@ -345,7 +444,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
|||||||
return newServerRoutesMap, newClientRoutesIDMap
|
return newServerRoutesMap, newClientRoutesIDMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||||
_, crMap := m.classifyRoutes(initialRoutes)
|
_, crMap := m.classifyRoutes(initialRoutes)
|
||||||
rs := make([]*route.Route, 0, len(crMap))
|
rs := make([]*route.Route, 0, len(crMap))
|
||||||
for _, routes := range crMap {
|
for _, routes := range crMap {
|
||||||
|
|||||||
@@ -424,7 +424,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
statusRecorder := peer.NewRecorder("https://mgm")
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
|
routeManager := NewManager(ManagerConfig{
|
||||||
|
Context: ctx,
|
||||||
|
PublicKey: localPeerKey,
|
||||||
|
WGInterface: wgInterface,
|
||||||
|
StatusRecorder: statusRecorder,
|
||||||
|
})
|
||||||
|
|
||||||
_, _, err = routeManager.Init()
|
_, _, err = routeManager.Init()
|
||||||
|
|
||||||
@@ -436,11 +441,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(testCase.inputInitRoutes) > 0 {
|
if len(testCase.inputInitRoutes) > 0 {
|
||||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
|
||||||
require.NoError(t, err, "should update routes with init routes")
|
require.NoError(t, err, "should update routes with init routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
|
||||||
require.NoError(t, err, "should update routes")
|
require.NoError(t, err, "should update routes")
|
||||||
|
|
||||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||||
@@ -450,8 +455,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||||
|
|
||||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||||
sr := routeManager.serverRouter.(*defaultServerRouter)
|
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||||
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
@@ -15,10 +14,12 @@ import (
|
|||||||
|
|
||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
type MockManager struct {
|
type MockManager struct {
|
||||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
TriggerSelectionFunc func(haMap route.HAMap)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
StopFunc func(manager *statemanager.Manager)
|
GetClientRoutesFunc func() route.HAMap
|
||||||
|
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||||
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||||
@@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error {
|
||||||
if m.UpdateRoutesFunc != nil {
|
if m.UpdateRoutesFunc != nil {
|
||||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||||
}
|
}
|
||||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
||||||
@@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||||
|
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||||
|
if m.GetClientRoutesFunc != nil {
|
||||||
|
return m.GetClientRoutesFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||||
|
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
|
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||||
|
return m.GetClientRoutesWithNetIDFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Start mock implementation of Start from Manager interface
|
// Start mock implementation of Start from Manager interface
|
||||||
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
|
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/route"
|
|
||||||
|
|
||||||
type serverRouter interface {
|
|
||||||
updateRoutes(map[route.ID]*route.Route) error
|
|
||||||
removeFromServerNetwork(*route.Route) error
|
|
||||||
cleanUp()
|
|
||||||
}
|
|
||||||
@@ -9,8 +9,19 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
|
type serverRouter struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r serverRouter) cleanUp() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
|
||||||
return nil, fmt.Errorf("server route not supported on this os")
|
return nil, fmt.Errorf("server route not supported on this os")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type defaultServerRouter struct {
|
type serverRouter struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
routes map[route.ID]*route.Route
|
routes map[route.ID]*route.Route
|
||||||
@@ -26,8 +26,8 @@ type defaultServerRouter struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||||
return &defaultServerRouter{
|
return &serverRouter{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
routes: make(map[route.ID]*route.Route),
|
routes: make(map[route.ID]*route.Route),
|
||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
@@ -36,7 +36,7 @@ func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall f
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
||||||
serverRoutesToRemove := make([]route.ID, 0)
|
serverRoutesToRemove := make([]route.ID, 0)
|
||||||
|
|
||||||
for routeID := range m.routes {
|
for routeID := range m.routes {
|
||||||
@@ -80,74 +80,72 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||||
select {
|
if m.ctx.Err() != nil {
|
||||||
case <-m.ctx.Done():
|
|
||||||
log.Infof("Not removing from server network because context is done")
|
log.Infof("Not removing from server network because context is done")
|
||||||
return m.ctx.Err()
|
return m.ctx.Err()
|
||||||
default:
|
|
||||||
m.mux.Lock()
|
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
routerPair, err := routeToRouterPair(route)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse prefix: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.firewall.RemoveNatRule(routerPair)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("remove routing rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(m.routes, route.ID)
|
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
|
||||||
delete(state.Routes, route.Network.String())
|
|
||||||
m.statusRecorder.UpdateLocalPeerState(state)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
routerPair, err := routeToRouterPair(route)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse prefix: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.firewall.RemoveNatRule(routerPair)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("remove routing rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.routes, route.ID)
|
||||||
|
|
||||||
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
|
delete(state.Routes, route.Network.String())
|
||||||
|
m.statusRecorder.UpdateLocalPeerState(state)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
||||||
select {
|
if m.ctx.Err() != nil {
|
||||||
case <-m.ctx.Done():
|
|
||||||
log.Infof("Not adding to server network because context is done")
|
log.Infof("Not adding to server network because context is done")
|
||||||
return m.ctx.Err()
|
return m.ctx.Err()
|
||||||
default:
|
|
||||||
m.mux.Lock()
|
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
routerPair, err := routeToRouterPair(route)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse prefix: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.firewall.AddNatRule(routerPair)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("insert routing rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.routes[route.ID] = route
|
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
|
||||||
if state.Routes == nil {
|
|
||||||
state.Routes = map[string]struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
routeStr := route.Network.String()
|
|
||||||
if route.IsDynamic() {
|
|
||||||
routeStr = route.Domains.SafeString()
|
|
||||||
}
|
|
||||||
state.Routes[routeStr] = struct{}{}
|
|
||||||
|
|
||||||
m.statusRecorder.UpdateLocalPeerState(state)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
routerPair, err := routeToRouterPair(route)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse prefix: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.firewall.AddNatRule(routerPair)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert routing rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routes[route.ID] = route
|
||||||
|
|
||||||
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
|
if state.Routes == nil {
|
||||||
|
state.Routes = map[string]struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
routeStr := route.Network.String()
|
||||||
|
if route.IsDynamic() {
|
||||||
|
routeStr = route.Domains.SafeString()
|
||||||
|
}
|
||||||
|
state.Routes[routeStr] = struct{}{}
|
||||||
|
|
||||||
|
m.statusRecorder.UpdateLocalPeerState(state)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *defaultServerRouter) cleanUp() {
|
func (m *serverRouter) cleanUp() {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
for _, r := range m.routes {
|
for _, r := range m.routes {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
@@ -62,6 +63,17 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
|||||||
r.removeFromRouteTable,
|
r.removeFromRouteTable,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
refCounter = refcounter.New(
|
||||||
|
func(netip.Prefix, struct{}) (Nexthop, error) {
|
||||||
|
return Nexthop{}, refcounter.ErrIgnore
|
||||||
|
},
|
||||||
|
func(netip.Prefix, Nexthop) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
r.refCounter = refCounter
|
r.refCounter = refCounter
|
||||||
|
|
||||||
return r.setupHooks(initAddresses, stateManager)
|
return r.setupHooks(initAddresses, stateManager)
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
|||||||
return fmt.Errorf("add gateway and device: %w", err)
|
return fmt.Errorf("add gateway and device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
||||||
return fmt.Errorf("netlink add route: %w", err)
|
return fmt.Errorf("netlink add route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -289,7 +289,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
|||||||
Dst: ipNet,
|
Dst: ipNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
||||||
return fmt.Errorf("netlink add unreachable route: %w", err)
|
return fmt.Errorf("netlink add unreachable route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,7 +312,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
|||||||
if err := netlink.RouteDel(route); err != nil &&
|
if err := netlink.RouteDel(route); err != nil &&
|
||||||
!errors.Is(err, syscall.ESRCH) &&
|
!errors.Is(err, syscall.ESRCH) &&
|
||||||
!errors.Is(err, syscall.ENOENT) &&
|
!errors.Is(err, syscall.ENOENT) &&
|
||||||
!errors.Is(err, syscall.EAFNOSUPPORT) {
|
!isOpErr(err) {
|
||||||
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,7 +338,7 @@ func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
|||||||
return fmt.Errorf("add gateway and device: %w", err)
|
return fmt.Errorf("add gateway and device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !isOpErr(err) {
|
||||||
return fmt.Errorf("netlink remove route: %w", err)
|
return fmt.Errorf("netlink remove route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,7 +362,7 @@ func flushRoutes(tableID, family int) error {
|
|||||||
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
|
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RouteDel(&routes[i]); err != nil && !isOpErr(err) {
|
||||||
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
|
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
|
|||||||
rule.Invert = params.invert
|
rule.Invert = params.invert
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
||||||
return fmt.Errorf("add routing rule: %w", err)
|
return fmt.Errorf("add routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
|
|||||||
rule.Priority = params.priority
|
rule.Priority = params.priority
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !isOpErr(err) {
|
||||||
return fmt.Errorf("remove routing rule: %w", err)
|
return fmt.Errorf("remove routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -509,3 +509,13 @@ func hasSeparateRouting() ([]netip.Prefix, error) {
|
|||||||
}
|
}
|
||||||
return nil, ErrRoutingIsSeparate
|
return nil, ErrRoutingIsSeparate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOpErr(err error) bool {
|
||||||
|
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
||||||
|
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
|
log.Debugf("route operation not supported: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -303,20 +303,29 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
|
|||||||
|
|
||||||
var rawStates map[string]json.RawMessage
|
var rawStates map[string]json.RawMessage
|
||||||
if err := json.Unmarshal(data, &rawStates); err != nil {
|
if err := json.Unmarshal(data, &rawStates); err != nil {
|
||||||
if deleteCorrupt {
|
m.handleCorruptedState(deleteCorrupt)
|
||||||
log.Warn("State file appears to be corrupted, attempting to delete it", err)
|
|
||||||
if err := os.Remove(m.filePath); err != nil {
|
|
||||||
log.Errorf("Failed to delete corrupted state file: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Info("State file deleted")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unmarshal states: %w", err)
|
return nil, fmt.Errorf("unmarshal states: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawStates, nil
|
return rawStates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleCorruptedState creates a backup of a corrupted state file by moving it
|
||||||
|
func (m *Manager) handleCorruptedState(deleteCorrupt bool) {
|
||||||
|
if !deleteCorrupt {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Warn("State file appears to be corrupted, attempting to back it up")
|
||||||
|
|
||||||
|
backupPath := fmt.Sprintf("%s.corrupted.%d", m.filePath, time.Now().UnixNano())
|
||||||
|
if err := os.Rename(m.filePath, backupPath); err != nil {
|
||||||
|
log.Errorf("Failed to backup corrupted state file: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Created backup of corrupted state file at: %s", backupPath)
|
||||||
|
}
|
||||||
|
|
||||||
// loadSingleRawState unmarshals a raw state into a concrete state object
|
// loadSingleRawState unmarshals a raw state into a concrete state object
|
||||||
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
|
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
|
||||||
stateType, ok := m.stateTypes[name]
|
stateType, ok := m.stateTypes[name]
|
||||||
|
|||||||
@@ -1,23 +1,16 @@
|
|||||||
package statemanager
|
package statemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||||
// It returns an empty string if the path cannot be determined.
|
// It returns an empty string if the path cannot be determined.
|
||||||
func GetDefaultStatePath() string {
|
func GetDefaultStatePath() string {
|
||||||
switch runtime.GOOS {
|
if path := os.Getenv("NB_DNS_STATE_FILE"); path != "" {
|
||||||
case "windows":
|
return path
|
||||||
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
|
||||||
case "darwin", "linux":
|
|
||||||
return "/var/lib/netbird/state.json"
|
|
||||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
|
||||||
return "/var/db/netbird/state.json"
|
|
||||||
}
|
}
|
||||||
|
return filepath.Join(configs.StateDir, "state.json")
|
||||||
return ""
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ func (c *Client) IsLoginRequired() bool {
|
|||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
|
|
||||||
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey)
|
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
|
||||||
return needsLogin
|
return needsLogin
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -272,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
|||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
routesMap := engine.GetClientRoutesWithNetID()
|
|
||||||
routeManager := engine.GetRouteManager()
|
routeManager := engine.GetRouteManager()
|
||||||
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
if routeManager == nil {
|
if routeManager == nil {
|
||||||
return nil, fmt.Errorf("could not get route manager")
|
return nil, fmt.Errorf("could not get route manager")
|
||||||
}
|
}
|
||||||
@@ -317,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
|
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
|
||||||
var routeSelection []RoutesSelectionInfo
|
var routeSelection []RoutesSelectionInfo
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
domainList := make([]DomainInfo, 0)
|
domainList := make([]DomainInfo, 0)
|
||||||
@@ -325,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
|||||||
domainResp := DomainInfo{
|
domainResp := DomainInfo{
|
||||||
Domain: d.SafeString(),
|
Domain: d.SafeString(),
|
||||||
}
|
}
|
||||||
if prefixes, exists := resolvedDomains[d]; exists {
|
|
||||||
|
if info, exists := resolvedDomains[d]; exists {
|
||||||
var ipStrings []string
|
var ipStrings []string
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range info.Prefixes {
|
||||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||||
}
|
}
|
||||||
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
||||||
@@ -365,12 +366,12 @@ func (c *Client) SelectRoute(id string) error {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("select route with id: %s", id)
|
log.Debugf("select route with id: %s", id)
|
||||||
routes := toNetIDs([]string{id})
|
routes := toNetIDs([]string{id})
|
||||||
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||||
log.Debugf("error when selecting routes: %s", err)
|
log.Debugf("error when selecting routes: %s", err)
|
||||||
return fmt.Errorf("select routes: %w", err)
|
return fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -392,12 +393,12 @@ func (c *Client) DeselectRoute(id string) error {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("deselect route with id: %s", id)
|
log.Debugf("deselect route with id: %s", id)
|
||||||
routes := toNetIDs([]string{id})
|
routes := toNetIDs([]string{id})
|
||||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||||
log.Debugf("error when deselecting routes: %s", err)
|
log.Debugf("error when deselecting routes: %s", err)
|
||||||
return fmt.Errorf("deselect routes: %w", err)
|
return fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ func (a *Auth) Login() error {
|
|||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user