mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 07:46:38 +00:00
Compare commits
75 Commits
v0.36.4
...
chore/benc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6d8591a6a | ||
|
|
770691fe40 | ||
|
|
1e1c078f9b | ||
|
|
99e02b83f1 | ||
|
|
5b6e93e927 | ||
|
|
6130a48c44 | ||
|
|
f7d73f24c7 | ||
|
|
cc48594b0b | ||
|
|
559e673107 | ||
|
|
b64bee35fa | ||
|
|
9a0354b681 | ||
|
|
73101c8977 | ||
|
|
73ce746ba7 | ||
|
|
a74208abac | ||
|
|
c02b55c911 | ||
|
|
d7523e6f3f | ||
|
|
2644f37025 | ||
|
|
4b70bfd9d7 | ||
|
|
b307298b2f | ||
|
|
33576671b6 | ||
|
|
f00a997167 | ||
|
|
50acf07def | ||
|
|
8ac63ea430 | ||
|
|
974144a381 | ||
|
|
5c81937ed6 | ||
|
|
5134e3a06a | ||
|
|
6554026a82 | ||
|
|
a854660402 | ||
|
|
a0b48f971c | ||
|
|
96de928cb3 | ||
|
|
77e40f41f2 | ||
|
|
640aa872f4 | ||
|
|
5132f23e0f | ||
|
|
1145d1a433 | ||
|
|
d7d5b1b1d6 | ||
|
|
631ef4ed28 | ||
|
|
39986b0e97 | ||
|
|
62a0c358f9 | ||
|
|
87311074f1 | ||
|
|
33cf9535b3 | ||
|
|
7e6beee7f6 | ||
|
|
27b3891b14 | ||
|
|
2a864832c6 | ||
|
|
c974c12d65 | ||
|
|
50926bdbb4 | ||
|
|
bd381d59cd | ||
|
|
f67e56d3b9 | ||
|
|
8fb5a9ce11 | ||
|
|
4cdb2e533a | ||
|
|
abe8da697c | ||
|
|
039a985f41 | ||
|
|
c4a6dafd27 | ||
|
|
a930c2aecf | ||
|
|
d48edb9837 | ||
|
|
b41de7fcd1 | ||
|
|
18f84f0df5 | ||
|
|
44407a158a | ||
|
|
488b697479 | ||
|
|
5953b43ead | ||
|
|
58b2eb4b92 | ||
|
|
05415f72ec | ||
|
|
b7af53ea40 | ||
|
|
cee4aeea9e | ||
|
|
ca9aca9b19 | ||
|
|
e00a280329 | ||
|
|
fe370e7d8f | ||
|
|
125b5e2b16 | ||
|
|
97d498c59c | ||
|
|
0125cd97d8 | ||
|
|
7d385b8dc3 | ||
|
|
f930ef2ee6 | ||
|
|
771c99a523 | ||
|
|
e20be2397c | ||
|
|
46766e7e24 | ||
|
|
a7ddb8f1f8 |
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Test Code Darwin
|
||||
name: "Darwin"
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,9 +12,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
store: ['sqlite']
|
||||
name: "Client / Unit"
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
|
||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -1,5 +1,4 @@
|
||||
|
||||
name: Test Code FreeBSD
|
||||
name: "FreeBSD"
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -13,6 +12,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
169
.github/workflows/golang-test-linux.yml
vendored
169
.github/workflows/golang-test-linux.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Test Code Linux
|
||||
name: Linux
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,11 +12,21 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-cache:
|
||||
name: "Build Cache"
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
management:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
@@ -38,7 +48,6 @@ jobs:
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
@@ -89,6 +98,7 @@ jobs:
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -134,9 +144,116 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
|
||||
test_relay:
|
||||
name: "Relay / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./signal/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./signal/...
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -194,10 +311,17 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
go test -tags=devcert \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 10m ./management/...
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -254,16 +378,24 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
go test -tags=internal_benchmark -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./...
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'postgres' ]
|
||||
run: [ '1', '2', '3', '4', '5']
|
||||
runs-on: ubuntu-latest-m
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
@@ -312,12 +444,21 @@ jobs:
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
go test -p 1 -tags=api_benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/...
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -363,9 +504,15 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
go test -tags=integration \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 10m ./management/...
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
|
||||
3
.github/workflows/golang-test-windows.yml
vendored
3
.github/workflows/golang-test-windows.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Test Code Windows
|
||||
name: "Windows"
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -14,6 +14,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
11
.github/workflows/golangci-lint.yml
vendored
11
.github/workflows/golangci-lint.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: golangci-lint
|
||||
name: Lint
|
||||
on: [pull_request]
|
||||
|
||||
permissions:
|
||||
@@ -27,7 +27,14 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest, windows-latest, ubuntu-latest]
|
||||
name: lint
|
||||
include:
|
||||
- os: macos-latest
|
||||
display_name: Darwin
|
||||
- os: windows-latest
|
||||
display_name: Windows
|
||||
- os: ubuntu-latest
|
||||
display_name: Linux
|
||||
name: ${{ matrix.display_name }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Mobile build validation
|
||||
name: Mobile
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,6 +12,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
android_build:
|
||||
name: "Android / Build"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -47,6 +48,7 @@ jobs:
|
||||
CGO_ENABLED: 0
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
ios_build:
|
||||
name: "iOS / Build"
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -9,10 +9,10 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.17"
|
||||
SIGN_PIPE_VER: "v0.0.18"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,3 +29,4 @@ infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
vendor/
|
||||
|
||||
@@ -103,7 +103,7 @@ linters:
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
|
||||
@@ -50,10 +50,12 @@ nfpms:
|
||||
- netbird-ui
|
||||
formats:
|
||||
- deb
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
- src: client/ui/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
@@ -67,10 +69,12 @@ nfpms:
|
||||
- netbird-ui
|
||||
formats:
|
||||
- rpm
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
- src: client/ui/netbird.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
|
||||
2
AUTHORS
2
AUTHORS
@@ -1,3 +1,3 @@
|
||||
Mikhail Bragin (https://github.com/braginini)
|
||||
Maycon Santos (https://github.com/mlsmaycon)
|
||||
Wiretrustee UG (haftungsbeschränkt)
|
||||
NetBird GmbH
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
We are incredibly thankful for the contributions we receive from the community.
|
||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
||||
as BSD-3 while allowing Wiretrustee to build a sustainable business.
|
||||
as BSD-3 while allowing NetBird to build a sustainable business.
|
||||
|
||||
Wiretrustee is committed to having a true Open Source Software ("OSS") license for
|
||||
our software. A CLA enables Wiretrustee to safely commercialize our products
|
||||
NetBird is committed to having a true Open Source Software ("OSS") license for
|
||||
our software. A CLA enables NetBird to safely commercialize our products
|
||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
||||
ability to use the project in their own projects or businesses, to republish modified
|
||||
source, or to completely fork the project.
|
||||
@@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
|
||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
||||
carefully review all the terms of the actual CLA before agreeing.
|
||||
|
||||
<li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work
|
||||
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
||||
in commercial products.
|
||||
</li>
|
||||
|
||||
<li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a
|
||||
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
||||
license to use that patent including within commercial products. You also agree that you
|
||||
have permission to grant this license.
|
||||
</li>
|
||||
@@ -45,7 +45,7 @@ more.
|
||||
# Why require a CLA?
|
||||
|
||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
||||
to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial
|
||||
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
||||
products.
|
||||
|
||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
||||
@@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
|
||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
||||
|
||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not
|
||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
||||
require you to sign again.
|
||||
|
||||
# Legal Terms and Agreement
|
||||
|
||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee
|
||||
UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
||||
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
||||
your own Contributions for any other purpose.
|
||||
|
||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
||||
Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee,
|
||||
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
||||
You reserve all right, title, and interest in and to Your Contributions.
|
||||
|
||||
1. Definitions.
|
||||
|
||||
```
|
||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
||||
that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other
|
||||
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
||||
entities that control, are controlled by, or are under common control with that entity are considered
|
||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
||||
@@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
```
|
||||
```
|
||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
||||
an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in,
|
||||
or documentation of, any of the products owned or managed by Wiretrustee (the "Work").
|
||||
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
||||
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
||||
sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists,
|
||||
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
||||
Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
||||
```
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee
|
||||
and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge,
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
||||
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
||||
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and
|
||||
to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
||||
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
||||
@@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
||||
your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA
|
||||
with Wiretrustee.
|
||||
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
||||
with NetBird.
|
||||
|
||||
|
||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
||||
@@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
|
||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
|
||||
|
||||
7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from
|
||||
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||
|
||||
|
||||
8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these
|
||||
representations inaccurate in any respect.
|
||||
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
||||
representations inaccurate in any respect.
|
||||
|
||||
4
LICENSE
4
LICENSE
@@ -1,6 +1,6 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS
|
||||
Copyright (c) 2022 NetBird GmbH & AUTHORS
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
@@ -10,4 +10,4 @@ Redistribution and use in source and binary forms, with or without modification,
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
<div align="center">
|
||||
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
||||
Webinar: How to Achieve Zero Trust Access to Kubernetes — Effortlessly
|
||||
</a>
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM alpine:3.21.0
|
||||
FROM alpine:3.21.3
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
|
||||
@@ -9,6 +9,7 @@ USER netbird:netbird
|
||||
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENV NB_USE_NETSTACK_MODE=true
|
||||
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
|
||||
ENV NB_CONFIG=config.json
|
||||
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||
ENV NB_DISABLE_DNS=true
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
@@ -85,7 +86,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -196,7 +197,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
@@ -206,7 +207,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
@@ -271,13 +272,15 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command) string {
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
@@ -85,11 +85,17 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
var dnsLabelsReq []string
|
||||
if dnsLabelsValidated != nil {
|
||||
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
|
||||
@@ -2,107 +2,20 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type peerStateDetailOutput struct {
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
||||
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
|
||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
Networks []string `json:"networks" yaml:"networks"`
|
||||
}
|
||||
|
||||
type peersStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Connected int `json:"connected" yaml:"connected"`
|
||||
Details []peerStateDetailOutput `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type signalStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type managementStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutputDetail struct {
|
||||
URI string `json:"uri" yaml:"uri"`
|
||||
Available bool `json:"available" yaml:"available"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Available int `json:"available" yaml:"available"`
|
||||
Details []relayStateOutputDetail `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type iceCandidateType struct {
|
||||
Local string `json:"local" yaml:"local"`
|
||||
Remote string `json:"remote" yaml:"remote"`
|
||||
}
|
||||
|
||||
type nsServerGroupStateOutput struct {
|
||||
Servers []string `json:"servers" yaml:"servers"`
|
||||
Domains []string `json:"domains" yaml:"domains"`
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type statusOutputOverview struct {
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
Networks []string `json:"networks" yaml:"networks"`
|
||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||
}
|
||||
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
@@ -173,18 +86,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
case jsonFlag:
|
||||
statusOutputString, err = parseToJSON(outputInformationHolder)
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
case yamlFlag:
|
||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -214,7 +126,6 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
@@ -251,175 +162,6 @@ func enableDetailFlagWhenFilterFlag() {
|
||||
}
|
||||
}
|
||||
|
||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
managementState := pbFullStatus.GetManagementState()
|
||||
managementOverview := managementStateOutput{
|
||||
URL: managementState.GetURL(),
|
||||
Connected: managementState.GetConnected(),
|
||||
Error: managementState.Error,
|
||||
}
|
||||
|
||||
signalState := pbFullStatus.GetSignalState()
|
||||
signalOverview := signalStateOutput{
|
||||
URL: signalState.GetURL(),
|
||||
Connected: signalState.GetConnected(),
|
||||
Error: signalState.Error,
|
||||
}
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
||||
|
||||
overview := statusOutputOverview{
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
Relays: relayOverview,
|
||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
}
|
||||
|
||||
if anonymizeFlag {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
anonymizeOverview(anonymizer, &overview)
|
||||
}
|
||||
|
||||
return overview
|
||||
}
|
||||
|
||||
func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
||||
var relayStateDetail []relayStateOutputDetail
|
||||
|
||||
var relaysAvailable int
|
||||
for _, relay := range relays {
|
||||
available := relay.GetAvailable()
|
||||
relayStateDetail = append(relayStateDetail,
|
||||
relayStateOutputDetail{
|
||||
URI: relay.URI,
|
||||
Available: available,
|
||||
Error: relay.GetError(),
|
||||
},
|
||||
)
|
||||
|
||||
if available {
|
||||
relaysAvailable++
|
||||
}
|
||||
}
|
||||
|
||||
return relayStateOutput{
|
||||
Total: len(relays),
|
||||
Available: relaysAvailable,
|
||||
Details: relayStateDetail,
|
||||
}
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
||||
Servers: pbNsGroupServer.GetServers(),
|
||||
Domains: pbNsGroupServer.GetDomains(),
|
||||
Enabled: pbNsGroupServer.GetEnabled(),
|
||||
Error: pbNsGroupServer.GetError(),
|
||||
})
|
||||
}
|
||||
return mappedNSGroups
|
||||
}
|
||||
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
peersConnected := 0
|
||||
for _, pbPeerState := range peers {
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
relayServerAddress := ""
|
||||
connType := ""
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
peersConnected++
|
||||
|
||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
transferSent = pbPeerState.GetBytesTx()
|
||||
}
|
||||
|
||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||
peerState := peerStateDetailOutput{
|
||||
IP: pbPeerState.GetIP(),
|
||||
PubKey: pbPeerState.GetPubKey(),
|
||||
Status: pbPeerState.GetConnStatus(),
|
||||
LastStatusUpdate: timeLocal,
|
||||
ConnType: connType,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: localICE,
|
||||
Remote: remoteICE,
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: localICEEndpoint,
|
||||
Remote: remoteICEEndpoint,
|
||||
},
|
||||
RelayAddress: relayServerAddress,
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
LastWireguardHandshake: lastHandshake,
|
||||
TransferReceived: transferReceived,
|
||||
TransferSent: transferSent,
|
||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||
Routes: pbPeerState.GetNetworks(),
|
||||
Networks: pbPeerState.GetNetworks(),
|
||||
}
|
||||
|
||||
peersStateDetail = append(peersStateDetail, peerState)
|
||||
}
|
||||
|
||||
sortPeersByIP(peersStateDetail)
|
||||
|
||||
peersOverview := peersStateOutput{
|
||||
Total: len(peersStateDetail),
|
||||
Connected: peersConnected,
|
||||
Details: peersStateDetail,
|
||||
}
|
||||
return peersOverview
|
||||
}
|
||||
|
||||
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
|
||||
if len(peersStateDetail) > 0 {
|
||||
sort.SliceStable(peersStateDetail, func(i, j int) bool {
|
||||
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
|
||||
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
|
||||
return iAddr.Compare(jAddr) == -1
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func parseInterfaceIP(interfaceIP string) string {
|
||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||
if err != nil {
|
||||
@@ -427,452 +169,3 @@ func parseInterfaceIP(interfaceIP string) string {
|
||||
}
|
||||
return fmt.Sprintf("%s\n", ip)
|
||||
}
|
||||
|
||||
func parseToJSON(overview statusOutputOverview) (string, error) {
|
||||
jsonBytes, err := json.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
func parseToYAML(overview statusOutputOverview) (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
}
|
||||
} else {
|
||||
managementConnString = "Disconnected"
|
||||
if overview.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
var signalConnString string
|
||||
if overview.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
}
|
||||
} else {
|
||||
signalConnString = "Disconnected"
|
||||
if overview.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
interfaceTypeString := "Userspace"
|
||||
interfaceIP := overview.IP
|
||||
if overview.KernelInterface {
|
||||
interfaceTypeString = "Kernel"
|
||||
} else if overview.IP == "" {
|
||||
interfaceTypeString = "N/A"
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
if !relay.Available {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(overview.Networks) > 0 {
|
||||
sort.Strings(overview.Networks)
|
||||
networks = strings.Join(overview.Networks, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
}
|
||||
errorString := ""
|
||||
if nsServerGroup.Error != "" {
|
||||
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
||||
errorString = strings.TrimSpace(errorString)
|
||||
}
|
||||
|
||||
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
||||
if domainsString == "" {
|
||||
domainsString = "." // Show "." for the default zone
|
||||
}
|
||||
dnsServersString += fmt.Sprintf(
|
||||
"\n [%s] for [%s] is %s%s",
|
||||
strings.Join(nsServerGroup.Servers, ", "),
|
||||
domainsString,
|
||||
enabled,
|
||||
errorString,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if overview.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
if overview.RosenpassPermissive {
|
||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
goarm := ""
|
||||
if goarch == "arm" {
|
||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
"OS: %s\n"+
|
||||
"Daemon version: %s\n"+
|
||||
"CLI version: %s\n"+
|
||||
"Management: %s\n"+
|
||||
"Signal: %s\n"+
|
||||
"Relays: %s\n"+
|
||||
"Nameservers: %s\n"+
|
||||
"FQDN: %s\n"+
|
||||
"NetBird IP: %s\n"+
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Routes: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
overview.FQDN,
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
networks,
|
||||
networks,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
}
|
||||
|
||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
summary := parseGeneralSummary(overview, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
"%s\n"+
|
||||
"%s",
|
||||
parsedPeersString,
|
||||
summary,
|
||||
)
|
||||
}
|
||||
|
||||
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
||||
var (
|
||||
peersString = ""
|
||||
)
|
||||
|
||||
for _, peerState := range peers.Details {
|
||||
|
||||
localICE := "-"
|
||||
if peerState.IceCandidateType.Local != "" {
|
||||
localICE = peerState.IceCandidateType.Local
|
||||
}
|
||||
|
||||
remoteICE := "-"
|
||||
if peerState.IceCandidateType.Remote != "" {
|
||||
remoteICE = peerState.IceCandidateType.Remote
|
||||
}
|
||||
|
||||
localICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Local != "" {
|
||||
localICEEndpoint = peerState.IceCandidateEndpoint.Local
|
||||
}
|
||||
|
||||
remoteICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if rosenpassEnabled {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
} else {
|
||||
if rosenpassPermissive {
|
||||
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
|
||||
} else {
|
||||
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
|
||||
}
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(peerState.Networks) > 0 {
|
||||
sort.Strings(peerState.Networks)
|
||||
networks = strings.Join(peerState.Networks, ", ")
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
"\n %s:\n"+
|
||||
" NetBird IP: %s\n"+
|
||||
" Public key: %s\n"+
|
||||
" Status: %s\n"+
|
||||
" -- detail --\n"+
|
||||
" Connection type: %s\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
||||
" Relay server address: %s\n"+
|
||||
" Last connection update: %s\n"+
|
||||
" Last WireGuard handshake: %s\n"+
|
||||
" Transfer status (received/sent) %s/%s\n"+
|
||||
" Quantum resistance: %s\n"+
|
||||
" Routes: %s\n"+
|
||||
" Networks: %s\n"+
|
||||
" Latency: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
peerState.Status,
|
||||
peerState.ConnType,
|
||||
localICE,
|
||||
remoteICE,
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
peerState.RelayAddress,
|
||||
timeAgo(peerState.LastStatusUpdate),
|
||||
timeAgo(peerState.LastWireguardHandshake),
|
||||
toIEC(peerState.TransferReceived),
|
||||
toIEC(peerState.TransferSent),
|
||||
rosenpassEnabledStatus,
|
||||
networks,
|
||||
networks,
|
||||
peerState.Latency.String(),
|
||||
)
|
||||
|
||||
peersString += peerString
|
||||
}
|
||||
return peersString
|
||||
}
|
||||
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := true
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
if lowerStatusFilter == "disconnected" && isConnected {
|
||||
statusEval = true
|
||||
} else if lowerStatusFilter == "connected" && !isConnected {
|
||||
statusEval = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(ipsFilter) > 0 {
|
||||
_, ok := ipsFilterMap[peerState.IP]
|
||||
if !ok {
|
||||
ipEval = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
}
|
||||
|
||||
func toIEC(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB",
|
||||
float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
||||
count := 0
|
||||
for _, server := range dnsServers {
|
||||
if server.Enabled {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
|
||||
func timeAgo(t time.Time) string {
|
||||
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
|
||||
return "-"
|
||||
}
|
||||
duration := time.Since(t)
|
||||
switch {
|
||||
case duration < time.Second:
|
||||
return "Now"
|
||||
case duration < time.Minute:
|
||||
seconds := int(duration.Seconds())
|
||||
if seconds == 1 {
|
||||
return "1 second ago"
|
||||
}
|
||||
return fmt.Sprintf("%d seconds ago", seconds)
|
||||
case duration < time.Hour:
|
||||
minutes := int(duration.Minutes())
|
||||
seconds := int(duration.Seconds()) % 60
|
||||
if minutes == 1 {
|
||||
if seconds == 1 {
|
||||
return "1 minute, 1 second ago"
|
||||
} else if seconds > 0 {
|
||||
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
|
||||
}
|
||||
return "1 minute ago"
|
||||
}
|
||||
if seconds > 0 {
|
||||
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
|
||||
}
|
||||
return fmt.Sprintf("%d minutes ago", minutes)
|
||||
case duration < 24*time.Hour:
|
||||
hours := int(duration.Hours())
|
||||
minutes := int(duration.Minutes()) % 60
|
||||
if hours == 1 {
|
||||
if minutes == 1 {
|
||||
return "1 hour, 1 minute ago"
|
||||
} else if minutes > 0 {
|
||||
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
|
||||
}
|
||||
return "1 hour ago"
|
||||
}
|
||||
if minutes > 0 {
|
||||
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%d hours ago", hours)
|
||||
}
|
||||
|
||||
days := int(duration.Hours()) / 24
|
||||
hours := int(duration.Hours()) % 24
|
||||
if days == 1 {
|
||||
if hours == 1 {
|
||||
return "1 day, 1 hour ago"
|
||||
} else if hours > 0 {
|
||||
return fmt.Sprintf("1 day, %d hours ago", hours)
|
||||
}
|
||||
return "1 day ago"
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%d days, %d hours ago", days, hours)
|
||||
}
|
||||
return fmt.Sprintf("%d days ago", days)
|
||||
}
|
||||
|
||||
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
|
||||
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
|
||||
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
|
||||
}
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
|
||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||
|
||||
for i, route := range peer.Networks {
|
||||
peer.Networks[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Networks {
|
||||
peer.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
||||
for i, peer := range overview.Peers.Details {
|
||||
peer := peer
|
||||
anonymizePeerDetail(a, &peer)
|
||||
overview.Peers.Details[i] = peer
|
||||
}
|
||||
|
||||
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
|
||||
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
|
||||
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
|
||||
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
|
||||
|
||||
overview.IP = a.AnonymizeIPString(overview.IP)
|
||||
for i, detail := range overview.Relays.Details {
|
||||
detail.URI = a.AnonymizeURI(detail.URI)
|
||||
detail.Error = a.AnonymizeString(detail.Error)
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
for j, ns := range nsGroup.Servers {
|
||||
host, port, err := net.SplitHostPort(ns)
|
||||
if err == nil {
|
||||
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
@@ -1,597 +1,11 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
func init() {
|
||||
loc, err := time.LoadLocation("UTC")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
time.Local = loc
|
||||
}
|
||||
|
||||
var resp = &proto.StatusResponse{
|
||||
Status: "Connected",
|
||||
FullStatus: &proto.FullStatus{
|
||||
Peers: []*proto.PeerState{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
Fqdn: "peer-1.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
LocalIceCandidateEndpoint: "",
|
||||
RemoteIceCandidateEndpoint: "",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||
BytesRx: 200,
|
||||
BytesTx: 100,
|
||||
Networks: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
Fqdn: "peer-2.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
||||
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
||||
BytesRx: 2000,
|
||||
BytesTx: 1000,
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
},
|
||||
},
|
||||
ManagementState: &proto.ManagementState{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: &proto.SignalState{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: []*proto.RelayState{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
LocalPeerState: &proto.LocalPeerState{
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
},
|
||||
DnsServers: []*proto.NSGroupState{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
},
|
||||
DaemonVersion: "0.14.1",
|
||||
}
|
||||
|
||||
var overview = statusOutputOverview{
|
||||
Peers: peersStateOutput{
|
||||
Total: 2,
|
||||
Connected: 2,
|
||||
Details: []peerStateDetailOutput{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
FQDN: "peer-1.awesome-domain.com",
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
||||
ConnType: "P2P",
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
||||
TransferReceived: 200,
|
||||
TransferSent: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Networks: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
FQDN: "peer-2.awesome-domain.com",
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
||||
ConnType: "Relayed",
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "relay",
|
||||
Remote: "prflx",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "10.0.0.1:10001",
|
||||
Remote: "10.0.10.1:10002",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
||||
TransferReceived: 2000,
|
||||
TransferSent: 1000,
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
},
|
||||
},
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: "0.14.1",
|
||||
ManagementState: managementStateOutput{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: signalStateOutput{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: relayStateOutput{
|
||||
Total: 2,
|
||||
Available: 1,
|
||||
Details: []relayStateOutputDetail{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
},
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
FQDN: "some-localhost.awesome-domain.com",
|
||||
NSServerGroups: []nsServerGroupStateOutput{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
convertedResult := convertToStatusOutputOverview(resp)
|
||||
|
||||
assert.Equal(t, overview, convertedResult)
|
||||
}
|
||||
|
||||
func TestSortingOfPeers(t *testing.T) {
|
||||
peers := []peerStateDetailOutput{
|
||||
{
|
||||
IP: "192.168.178.104",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.105",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.103",
|
||||
},
|
||||
}
|
||||
|
||||
sortPeersByIP(peers)
|
||||
|
||||
assert.Equal(t, peers[3].IP, "192.168.178.104")
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
jsonString, _ := parseToJSON(overview)
|
||||
|
||||
//@formatter:off
|
||||
expectedJSONString := `
|
||||
{
|
||||
"peers": {
|
||||
"total": 2,
|
||||
"connected": 2,
|
||||
"details": [
|
||||
{
|
||||
"fqdn": "peer-1.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.101",
|
||||
"publicKey": "Pubkey1",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
||||
"connectionType": "P2P",
|
||||
"iceCandidateType": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||
"transferReceived": 200,
|
||||
"transferSent": 100,
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": [
|
||||
"10.1.0.0/24"
|
||||
],
|
||||
"networks": [
|
||||
"10.1.0.0/24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"fqdn": "peer-2.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.102",
|
||||
"publicKey": "Pubkey2",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
||||
"connectionType": "Relayed",
|
||||
"iceCandidateType": {
|
||||
"local": "relay",
|
||||
"remote": "prflx"
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "10.0.0.1:10001",
|
||||
"remote": "10.0.10.1:10002"
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||
"transferReceived": 2000,
|
||||
"transferSent": 1000,
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": null,
|
||||
"networks": null
|
||||
}
|
||||
]
|
||||
},
|
||||
"cliVersion": "development",
|
||||
"daemonVersion": "0.14.1",
|
||||
"management": {
|
||||
"url": "my-awesome-management.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"signal": {
|
||||
"url": "my-awesome-signal.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"relays": {
|
||||
"total": 2,
|
||||
"available": 1,
|
||||
"details": [
|
||||
{
|
||||
"uri": "stun:my-awesome-stun.com:3478",
|
||||
"available": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
"available": false,
|
||||
"error": "context: deadline exceeded"
|
||||
}
|
||||
]
|
||||
},
|
||||
"netbirdIp": "192.168.178.100/16",
|
||||
"publicKey": "Some-Pub-Key",
|
||||
"usesKernelInterface": true,
|
||||
"fqdn": "some-localhost.awesome-domain.com",
|
||||
"quantumResistance": false,
|
||||
"quantumResistancePermissive": false,
|
||||
"routes": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"networks": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"dnsServers": [
|
||||
{
|
||||
"servers": [
|
||||
"8.8.8.8:53"
|
||||
],
|
||||
"domains": null,
|
||||
"enabled": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"servers": [
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53"
|
||||
],
|
||||
"domains": [
|
||||
"example.com",
|
||||
"example.net"
|
||||
],
|
||||
"enabled": false,
|
||||
"error": "timeout"
|
||||
}
|
||||
]
|
||||
}`
|
||||
// @formatter:on
|
||||
|
||||
var expectedJSON bytes.Buffer
|
||||
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
|
||||
|
||||
assert.Equal(t, expectedJSON.String(), jsonString)
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := parseToYAML(overview)
|
||||
|
||||
expectedYAML :=
|
||||
`peers:
|
||||
total: 2
|
||||
connected: 2
|
||||
details:
|
||||
- fqdn: peer-1.awesome-domain.com
|
||||
netbirdIp: 192.168.178.101
|
||||
publicKey: Pubkey1
|
||||
status: Connected
|
||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
||||
connectionType: P2P
|
||||
iceCandidateType:
|
||||
local: ""
|
||||
remote: ""
|
||||
iceCandidateEndpoint:
|
||||
local: ""
|
||||
remote: ""
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||
transferReceived: 200
|
||||
transferSent: 100
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes:
|
||||
- 10.1.0.0/24
|
||||
networks:
|
||||
- 10.1.0.0/24
|
||||
- fqdn: peer-2.awesome-domain.com
|
||||
netbirdIp: 192.168.178.102
|
||||
publicKey: Pubkey2
|
||||
status: Connected
|
||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
||||
connectionType: Relayed
|
||||
iceCandidateType:
|
||||
local: relay
|
||||
remote: prflx
|
||||
iceCandidateEndpoint:
|
||||
local: 10.0.0.1:10001
|
||||
remote: 10.0.10.1:10002
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||
transferReceived: 2000
|
||||
transferSent: 1000
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes: []
|
||||
networks: []
|
||||
cliVersion: development
|
||||
daemonVersion: 0.14.1
|
||||
management:
|
||||
url: my-awesome-management.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
signal:
|
||||
url: my-awesome-signal.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
relays:
|
||||
total: 2
|
||||
available: 1
|
||||
details:
|
||||
- uri: stun:my-awesome-stun.com:3478
|
||||
available: true
|
||||
error: ""
|
||||
- uri: turns:my-awesome-turn.com:443?transport=tcp
|
||||
available: false
|
||||
error: 'context: deadline exceeded'
|
||||
netbirdIp: 192.168.178.100/16
|
||||
publicKey: Some-Pub-Key
|
||||
usesKernelInterface: true
|
||||
fqdn: some-localhost.awesome-domain.com
|
||||
quantumResistance: false
|
||||
quantumResistancePermissive: false
|
||||
routes:
|
||||
- 10.10.0.0/24
|
||||
networks:
|
||||
- 10.10.0.0/24
|
||||
dnsServers:
|
||||
- servers:
|
||||
- 8.8.8.8:53
|
||||
domains: []
|
||||
enabled: true
|
||||
error: ""
|
||||
- servers:
|
||||
- 1.1.1.1:53
|
||||
- 2.2.2.2:53
|
||||
domains:
|
||||
- example.com
|
||||
- example.net
|
||||
enabled: false
|
||||
error: timeout
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
}
|
||||
|
||||
func TestParsingToDetail(t *testing.T) {
|
||||
// Calculate time ago based on the fixture dates
|
||||
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
|
||||
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
|
||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||
|
||||
detail := parseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail := fmt.Sprintf(
|
||||
`Peers detail:
|
||||
peer-1.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.101
|
||||
Public key: Pubkey1
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: P2P
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
Quantum resistance: false
|
||||
Routes: 10.1.0.0/24
|
||||
Networks: 10.1.0.0/24
|
||||
Latency: 10ms
|
||||
|
||||
peer-2.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.102
|
||||
Public key: Pubkey2
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: Relayed
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
Quantum resistance: false
|
||||
Routes: -
|
||||
Networks: -
|
||||
Latency: 10ms
|
||||
|
||||
OS: %s/%s
|
||||
Daemon version: 0.14.1
|
||||
CLI version: %s
|
||||
Management: Connected to my-awesome-management.com:443
|
||||
Signal: Connected to my-awesome-signal.com:443
|
||||
Relays:
|
||||
[stun:my-awesome-stun.com:3478] is Available
|
||||
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
||||
Nameservers:
|
||||
[8.8.8.8:53] for [.] is Available
|
||||
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
||||
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected
|
||||
Signal: Connected
|
||||
Relays: 1/2 Available
|
||||
Nameservers: 1/2 Available
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedString, shortVersion)
|
||||
}
|
||||
|
||||
func TestParsingOfIP(t *testing.T) {
|
||||
InterfaceIP := "192.168.178.123/16"
|
||||
|
||||
@@ -599,31 +13,3 @@ func TestParsingOfIP(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||
}
|
||||
|
||||
func TestTimeAgo(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input time.Time
|
||||
expected string
|
||||
}{
|
||||
{"Now", now, "Now"},
|
||||
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
|
||||
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
|
||||
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
|
||||
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
|
||||
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
|
||||
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
|
||||
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
|
||||
{"Zero time", time.Time{}, "-"},
|
||||
{"Unix zero time", time.Unix(0, 0), "-"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := timeAgo(tc.input)
|
||||
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
137
client/cmd/trace.go
Normal file
137
client/cmd/trace.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var traceCmd = &cobra.Command{
|
||||
Use: "trace <direction> <source-ip> <dest-ip>",
|
||||
Short: "Trace a packet through the firewall",
|
||||
Example: `
|
||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||
Args: cobra.ExactArgs(3),
|
||||
RunE: tracePacket,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.AddCommand(traceCmd)
|
||||
|
||||
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
|
||||
traceCmd.Flags().Uint16("sport", 0, "Source port")
|
||||
traceCmd.Flags().Uint16("dport", 0, "Destination port")
|
||||
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
|
||||
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
|
||||
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
|
||||
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
|
||||
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
|
||||
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
|
||||
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
|
||||
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
|
||||
}
|
||||
|
||||
func tracePacket(cmd *cobra.Command, args []string) error {
|
||||
direction := strings.ToLower(args[0])
|
||||
if direction != "in" && direction != "out" {
|
||||
return fmt.Errorf("invalid direction: use 'in' or 'out'")
|
||||
}
|
||||
|
||||
protocol := cmd.Flag("protocol").Value.String()
|
||||
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
|
||||
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
|
||||
}
|
||||
|
||||
sport, err := cmd.Flags().GetUint16("sport")
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid source port: %v", err)
|
||||
}
|
||||
dport, err := cmd.Flags().GetUint16("dport")
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination port: %v", err)
|
||||
}
|
||||
|
||||
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
|
||||
if protocol != "icmp" {
|
||||
if sport == 0 {
|
||||
sport = uint16(rand.Intn(16383) + 49152)
|
||||
}
|
||||
if dport == 0 {
|
||||
dport = uint16(rand.Intn(16383) + 49152)
|
||||
}
|
||||
}
|
||||
|
||||
var tcpFlags *proto.TCPFlags
|
||||
if protocol == "tcp" {
|
||||
syn, _ := cmd.Flags().GetBool("syn")
|
||||
ack, _ := cmd.Flags().GetBool("ack")
|
||||
fin, _ := cmd.Flags().GetBool("fin")
|
||||
rst, _ := cmd.Flags().GetBool("rst")
|
||||
psh, _ := cmd.Flags().GetBool("psh")
|
||||
urg, _ := cmd.Flags().GetBool("urg")
|
||||
|
||||
tcpFlags = &proto.TCPFlags{
|
||||
Syn: syn,
|
||||
Ack: ack,
|
||||
Fin: fin,
|
||||
Rst: rst,
|
||||
Psh: psh,
|
||||
Urg: urg,
|
||||
}
|
||||
}
|
||||
|
||||
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
|
||||
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
|
||||
SourceIp: args[1],
|
||||
DestinationIp: args[2],
|
||||
Protocol: protocol,
|
||||
SourcePort: uint32(sport),
|
||||
DestinationPort: uint32(dport),
|
||||
Direction: direction,
|
||||
TcpFlags: tcpFlags,
|
||||
IcmpType: &icmpType,
|
||||
IcmpCode: &icmpCode,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
|
||||
for _, stage := range resp.Stages {
|
||||
if stage.ForwardingDetails != nil {
|
||||
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
|
||||
} else {
|
||||
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
|
||||
}
|
||||
}
|
||||
|
||||
disposition := map[bool]string{
|
||||
true: "\033[32mALLOWED\033[0m", // Green
|
||||
false: "\033[31mDENIED\033[0m", // Red
|
||||
}[resp.FinalDisposition]
|
||||
|
||||
cmd.Printf("\nFinal disposition: %s\n", disposition)
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -29,9 +30,16 @@ const (
|
||||
interfaceInputType
|
||||
)
|
||||
|
||||
const (
|
||||
dnsLabelsFlag = "extra-dns-labels"
|
||||
)
|
||||
|
||||
var (
|
||||
foregroundMode bool
|
||||
upCmd = &cobra.Command{
|
||||
foregroundMode bool
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
Short: "install, login and start Netbird client",
|
||||
RunE: upFunc,
|
||||
@@ -49,6 +57,14 @@ func init() {
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels`+
|
||||
`You can specify a comma-separated list of up to 32 labels. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`,
|
||||
)
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -67,6 +83,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
if hostName != "" {
|
||||
@@ -98,6 +119,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
@@ -190,7 +212,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
return connectClient.Run()
|
||||
return connectClient.Run(nil)
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
@@ -240,6 +262,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -430,6 +454,24 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func validateDnsLabels(labels []string) (domain.List, error) {
|
||||
var (
|
||||
domains domain.List
|
||||
err error
|
||||
)
|
||||
|
||||
if len(labels) == 0 {
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
domains, err = domain.ValidateDomains(labels)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func isValidAddrPort(input string) bool {
|
||||
if input == "" {
|
||||
return true
|
||||
|
||||
167
client/embed/doc.go
Normal file
167
client/embed/doc.go
Normal file
@@ -0,0 +1,167 @@
|
||||
// Package embed provides a way to embed the NetBird client directly
|
||||
// into Go programs without requiring a separate NetBird client installation.
|
||||
package embed
|
||||
|
||||
// Basic Usage:
|
||||
//
|
||||
// client, err := embed.New(embed.Options{
|
||||
// DeviceName: "my-service",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// Complete HTTP Server Example:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "log"
|
||||
// "net/http"
|
||||
// "os"
|
||||
// "os/signal"
|
||||
// "syscall"
|
||||
// "time"
|
||||
//
|
||||
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// // Create client with setup key and device name
|
||||
// client, err := netbird.New(netbird.Options{
|
||||
// DeviceName: "http-server",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// LogOutput: io.Discard,
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Start with timeout
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Create HTTP server
|
||||
// mux := http.NewServeMux()
|
||||
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
|
||||
// fmt.Fprintf(w, "Hello from netbird!")
|
||||
// })
|
||||
//
|
||||
// // Listen on netbird network
|
||||
// l, err := client.ListenTCP(":8080")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// server := &http.Server{Handler: mux}
|
||||
// go func() {
|
||||
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
|
||||
// log.Printf("HTTP server error: %v", err)
|
||||
// }
|
||||
// }()
|
||||
//
|
||||
// log.Printf("HTTP server listening on netbird network port 8080")
|
||||
//
|
||||
// // Handle shutdown
|
||||
// stop := make(chan os.Signal, 1)
|
||||
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||
// <-stop
|
||||
//
|
||||
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
// log.Printf("HTTP shutdown error: %v", err)
|
||||
// }
|
||||
// if err := client.Stop(shutdownCtx); err != nil {
|
||||
// log.Printf("Netbird shutdown error: %v", err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Complete HTTP Client Example:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "io"
|
||||
// "log"
|
||||
// "os"
|
||||
// "time"
|
||||
//
|
||||
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// // Create client with setup key and device name
|
||||
// client, err := netbird.New(netbird.Options{
|
||||
// DeviceName: "http-client",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// LogOutput: io.Discard,
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Start with timeout
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Create HTTP client that uses netbird network
|
||||
// httpClient := client.NewHTTPClient()
|
||||
// httpClient.Timeout = 10 * time.Second
|
||||
//
|
||||
// // Make request to server in netbird network
|
||||
// target := os.Getenv("NB_TARGET")
|
||||
// resp, err := httpClient.Get(target)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
//
|
||||
// // Read and print response
|
||||
// body, err := io.ReadAll(resp.Body)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// fmt.Printf("Response from server: %s\n", string(body))
|
||||
//
|
||||
// // Clean shutdown
|
||||
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := client.Stop(shutdownCtx); err != nil {
|
||||
// log.Printf("Netbird shutdown error: %v", err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// The package provides several methods for network operations:
|
||||
// - Dial: Creates outbound connections
|
||||
// - ListenTCP: Creates TCP listeners
|
||||
// - ListenUDP: Creates UDP listeners
|
||||
//
|
||||
// By default, the embed package uses userspace networking mode, which doesn't
|
||||
// require root/admin privileges. For production deployments, consider setting
|
||||
// appropriate config and state paths for persistence.
|
||||
296
client/embed/embed.go
Normal file
296
client/embed/embed.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
|
||||
// Client manages a netbird embedded client instance
|
||||
type Client struct {
|
||||
deviceName string
|
||||
config *internal.Config
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
setupKey string
|
||||
connect *internal.ConnectClient
|
||||
}
|
||||
|
||||
// Options configures a new Client
|
||||
type Options struct {
|
||||
// DeviceName is this peer's name in the network
|
||||
DeviceName string
|
||||
// SetupKey is used for authentication
|
||||
SetupKey string
|
||||
// ManagementURL overrides the default management server URL
|
||||
ManagementURL string
|
||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||
PreSharedKey string
|
||||
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
||||
LogOutput io.Writer
|
||||
// LogLevel sets the logging level (defaults to info if empty)
|
||||
LogLevel string
|
||||
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
|
||||
NoUserspace bool
|
||||
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
|
||||
ConfigPath string
|
||||
// StatePath is the path to the netbird state file
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
}
|
||||
|
||||
// New creates a new netbird embedded client
|
||||
func New(opts Options) (*Client, error) {
|
||||
if opts.LogOutput != nil {
|
||||
logrus.SetOutput(opts.LogOutput)
|
||||
}
|
||||
|
||||
if opts.LogLevel != "" {
|
||||
level, err := logrus.ParseLevel(opts.LogLevel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse log level: %w", err)
|
||||
}
|
||||
logrus.SetLevel(level)
|
||||
}
|
||||
|
||||
if !opts.NoUserspace {
|
||||
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.StatePath != "" {
|
||||
// TODO: Disable state if path not provided
|
||||
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
t := true
|
||||
var config *internal.Config
|
||||
var err error
|
||||
input := internal.ConfigInput{
|
||||
ConfigPath: opts.ConfigPath,
|
||||
ManagementURL: opts.ManagementURL,
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = internal.UpdateOrCreateConfig(input)
|
||||
} else {
|
||||
config, err = internal.CreateInMemoryConfig(input)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create config: %w", err)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
|
||||
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
|
||||
func (c *Client) Start(startCtx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cancel != nil {
|
||||
return ErrClientAlreadyStarted
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
run <- err
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-run:
|
||||
if err != nil {
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
|
||||
}
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the client.
|
||||
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
|
||||
func (c *Client) Stop(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.connect == nil {
|
||||
return ErrClientNotStarted
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.connect.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancel = nil
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
c.cancel = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve: %w", err)
|
||||
}
|
||||
return nsnet.ListenTCP(tcpAddr)
|
||||
}
|
||||
|
||||
// ListenUDP listens on the given address in the netbird network
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve: %w", err)
|
||||
}
|
||||
|
||||
return nsnet.ListenUDP(udpAddr)
|
||||
}
|
||||
|
||||
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) NewHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
DialContext: c.Dial,
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, netip.Addr{}, errors.New("client not started")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, netip.Addr{}, errors.New("engine not started")
|
||||
}
|
||||
|
||||
addr, err := engine.Address()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
return nsnet, addr, nil
|
||||
}
|
||||
@@ -14,13 +14,13 @@ import (
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
fm, err := createNativeFirewall(iface, stateManager)
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
@@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
@@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -10,4 +12,6 @@ type IFaceMapper interface {
|
||||
Address() device.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
}
|
||||
|
||||
@@ -213,6 +213,19 @@ func (m *Manager) AllowNetbird() error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(log.Level) {
|
||||
// not supported
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering(
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
// after the established rule
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
||||
} else {
|
||||
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add route rule: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +99,12 @@ type Manager interface {
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
SetLogLevel(log.Level)
|
||||
|
||||
EnableRouting() error
|
||||
|
||||
DisableRouting() error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -348,6 +348,10 @@ func (m *AclManager) addIOFiltering(
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
nftRule: nftRule,
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
@@ -359,6 +363,7 @@ func (m *AclManager) addIOFiltering(
|
||||
if ipset != nil {
|
||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -318,6 +318,19 @@ func (m *Manager) cleanupNetbirdTables() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(log.Level) {
|
||||
// not supported
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
|
||||
@@ -107,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
@@ -307,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||
t.Helper()
|
||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||
|
||||
for i := range got {
|
||||
if _, isCounter := got[i].(*expr.Counter); isCounter {
|
||||
_, wantIsCounter := want[i].(*expr.Counter)
|
||||
require.True(t, wantIsCounter, "expected Counter at index %d", i)
|
||||
continue
|
||||
}
|
||||
|
||||
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering(
|
||||
UserData: []byte(ruleKey),
|
||||
}
|
||||
|
||||
rule = r.conn.AddRule(rule)
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
if action == firewall.ActionDrop {
|
||||
// TODO: Insert after the established rule
|
||||
rule = r.conn.InsertRule(rule)
|
||||
} else {
|
||||
rule = r.conn.AddRule(rule)
|
||||
}
|
||||
|
||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := m.logger.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to shutdown logger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
|
||||
16
client/firewall/uspfilter/common/iface.go
Normal file
16
client/firewall/uspfilter/common/iface.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetDevice() *device.FilteredDevice
|
||||
}
|
||||
@@ -10,12 +10,11 @@ import (
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
established atomic.Bool
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
@@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
||||
b.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (b *BaseConnTrack) IsEstablished() bool {
|
||||
return b.established.Load()
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (b *BaseConnTrack) SetEstablished(state bool) {
|
||||
b.established.Store(state)
|
||||
}
|
||||
|
||||
// GetLastSeen safely gets the last seen timestamp
|
||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||
return time.Unix(0, b.lastSeen.Load())
|
||||
|
||||
@@ -3,8 +3,14 @@ package conntrack
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
|
||||
func BenchmarkIPOperations(b *testing.B) {
|
||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
@@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) {
|
||||
})
|
||||
|
||||
}
|
||||
func BenchmarkAtomicOperations(b *testing.B) {
|
||||
conn := &BaseConnTrack{}
|
||||
b.Run("UpdateLastSeen", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsEstablished", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.IsEstablished()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("SetEstablished", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.SetEstablished(i%2 == 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetLastSeen", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.GetLastSeen()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
@@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,6 +35,7 @@ type ICMPConnTrack struct {
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
type ICMPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ICMPConnKey]*ICMPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
@@ -42,12 +45,13 @@ type ICMPTracker struct {
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
|
||||
tracker := &ICMPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
@@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
// TrackOutbound records an outbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New ICMP connection %v", key)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
switch icmpType {
|
||||
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
||||
uint8(layers.ICMPv4TypeTimeExceeded):
|
||||
return true
|
||||
case uint8(layers.ICMPv4TypeEchoReply):
|
||||
// continue processing
|
||||
default:
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
||||
return false
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.ID == id &&
|
||||
conn.Sequence == seq
|
||||
@@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
|
||||
@@ -5,7 +5,10 @@ package conntrack
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,12 +64,24 @@ type TCPConnKey struct {
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
State TCPState
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (t *TCPConnTrack) IsEstablished() bool {
|
||||
return t.established.Load()
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
}
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
@@ -76,8 +91,9 @@ type TCPTracker struct {
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
@@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
// Create key before lock
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
},
|
||||
State: TCPStateNew,
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
@@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, true)
|
||||
conn.Unlock()
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
@@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
@@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
|
||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
|
||||
case TCPStateTimeWait:
|
||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||
// This is handled by the cleanup routine
|
||||
|
||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
@@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,6 +22,7 @@ type UDPConnTrack struct {
|
||||
|
||||
// UDPTracker manages UDP connection states
|
||||
type UDPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*UDPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
@@ -29,12 +32,13 @@ type UDPTracker struct {
|
||||
}
|
||||
|
||||
// NewUDPTracker creates a new UDP connection tracker
|
||||
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
|
||||
tracker := &UDPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
@@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
@@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
DestPort: dstPort,
|
||||
},
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New UDP connection: %v", conn)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.lastSeen.Store(now)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
@@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
return false
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.DestPort == srcPort &&
|
||||
conn.SourcePort == dstPort
|
||||
@@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewUDPTracker(tt.timeout)
|
||||
tracker := NewUDPTracker(tt.timeout, logger)
|
||||
assert.NotNil(t, tracker)
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
@@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
@@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||
assert.Equal(t, srcPort, conn.SourcePort)
|
||||
assert.Equal(t, dstPort, conn.DestPort)
|
||||
assert.True(t, conn.IsEstablished())
|
||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||
}
|
||||
|
||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1 * time.Second)
|
||||
tracker := NewUDPTracker(1*time.Second, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
@@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
@@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
|
||||
func BenchmarkUDPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
|
||||
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
81
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||
type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
e.dispatcher = dispatcher
|
||||
}
|
||||
|
||||
func (e *endpoint) IsAttached() bool {
|
||||
return e.dispatcher != nil
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return stack.CapabilityNone
|
||||
}
|
||||
|
||||
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet through WireGuard
|
||||
address := netHeader.DestinationAddress()
|
||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||
if err != nil {
|
||||
e.logger.Error("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (e *endpoint) Wait() {
|
||||
// not required
|
||||
}
|
||||
|
||||
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
return header.ARPHardwareNone
|
||||
}
|
||||
|
||||
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||
// not required
|
||||
}
|
||||
|
||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
166
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultReceiveWindow = 32768
|
||||
defaultMaxInFlight = 1024
|
||||
iosReceiveWindow = 16384
|
||||
iosMaxInFlight = 256
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip net.IP
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
mtu, err := iface.GetDevice().MTU()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get MTU: %w", err)
|
||||
}
|
||||
nicID := tcpip.NICID(1)
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
ones, _ := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: ones,
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: defaultSubnet,
|
||||
NIC: nicID,
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: iface.Address().IP,
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
maxInFlight := defaultMaxInFlight
|
||||
if runtime.GOOS == "ios" {
|
||||
receiveWindow = iosReceiveWindow
|
||||
maxInFlight = iosMaxInFlight
|
||||
}
|
||||
|
||||
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
|
||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
if f.endpoint.dispatcher != nil {
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() {
|
||||
f.cancel()
|
||||
|
||||
if f.udpForwarder != nil {
|
||||
f.udpForwarder.Stop()
|
||||
}
|
||||
|
||||
f.stack.Close()
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
}
|
||||
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
109
client/firewall/uspfilter/forwarder/icmp.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
// Get the complete ICMP message (header + data)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
switch icmpHdr.Type() {
|
||||
case header.ICMPv4Echo:
|
||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||
case header.ICMPv4EchoReply:
|
||||
// dont process our own replies
|
||||
return true
|
||||
default:
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||
_, err = conn.WriteTo(payload, dst)
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||
return true
|
||||
}
|
||||
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
90
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
|
||||
// Complete the handshake
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
// Create context for managing the proxy goroutines
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(outConn, inConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(inConn, outConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||
return
|
||||
}
|
||||
}
|
||||
284
client/firewall/uspfilter/forwarder/udp.go
Normal file
284
client/firewall/uspfilter/forwarder/udp.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
const (
|
||||
udpTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type udpPacketConn struct {
|
||||
conn *gonet.UDPConn
|
||||
outConn net.Conn
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
ep tcpip.Endpoint
|
||||
}
|
||||
|
||||
type udpForwarder struct {
|
||||
sync.RWMutex
|
||||
logger *nblog.Logger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type idleConn struct {
|
||||
id stack.TransportEndpointID
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, mtu)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
go f.cleanup()
|
||||
return f
|
||||
}
|
||||
|
||||
// Stop stops the UDP forwarder and all active connections
|
||||
func (f *udpForwarder) Stop() {
|
||||
f.cancel()
|
||||
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
|
||||
for id, conn := range f.conns {
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
|
||||
conn.ep.Close()
|
||||
delete(f.conns, id)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup periodically removes idle UDP connections
|
||||
func (f *udpForwarder) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
var idleConns []idleConn
|
||||
|
||||
f.RLock()
|
||||
for id, conn := range f.conns {
|
||||
if conn.getIdleDuration() > udpTimeout {
|
||||
idleConns = append(idleConns, idleConn{id, conn})
|
||||
}
|
||||
}
|
||||
f.RUnlock()
|
||||
|
||||
for _, idle := range idleConns {
|
||||
idle.conn.cancel()
|
||||
if err := idle.conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
if err := idle.conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
|
||||
idle.conn.ep.Close()
|
||||
|
||||
f.Lock()
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
|
||||
f.udpForwarder.RLock()
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||
return
|
||||
}
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
pConn := &udpPacketConn{
|
||||
conn: inConn,
|
||||
outConn: outConn,
|
||||
cancel: connCancel,
|
||||
ep: ep,
|
||||
}
|
||||
pConn.updateLastSeen()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
// Double-check no connection was created while we were setting up
|
||||
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||
f.udpForwarder.Unlock()
|
||||
pConn.cancel()
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) updateLastSeen() {
|
||||
c.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||
return time.Since(lastSeen)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||
bufp := bufPool.Get().(*[]byte)
|
||||
defer bufPool.Put(bufp)
|
||||
buffer := *bufp
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read from %s: %w", direction, err)
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
|
||||
c.updateLastSeen()
|
||||
}
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
134
client/firewall/uspfilter/localip.go
Normal file
134
client/firewall/uspfilter/localip.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
)
|
||||
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||
ipv4Bitmap [1 << 16]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
return &localIPManager{}
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
if int(high) >= len(*newIPv4Bitmap) {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
}
|
||||
ipStr := ip.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [1 << 16]uint32
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
|
||||
// 127.0.0.0/8
|
||||
high := uint16(127) << 8
|
||||
for i := uint16(0); i < 256; i++ {
|
||||
newIPv4Bitmap[high|i] = 0xffffffff
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.ipv4Bitmap = newIPv4Bitmap
|
||||
m.mu.Unlock()
|
||||
|
||||
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return m.checkBitmapBit(ipv4)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
270
client/firewall/uspfilter/localip_test.go
Normal file
270
client/firewall/uspfilter/localip_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
func TestLocalIPManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAddr iface.WGAddress
|
||||
testIP net.IP
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.2"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.255.255.255"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return tt.setupAddr
|
||||
},
|
||||
}
|
||||
|
||||
err := manager.UpdateLocalIPs(mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.IsLocalIP(tt.testIP)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
mock := &IFaceMock{}
|
||||
|
||||
// Get actual local interfaces
|
||||
interfaces, err := net.Interfaces()
|
||||
require.NoError(t, err)
|
||||
|
||||
var tests []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}
|
||||
|
||||
// Add all local interface IPs to test cases
|
||||
for _, iface := range interfaces {
|
||||
addrs, err := iface.Addrs()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
tests = append(tests, struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
ip: ip4.String(),
|
||||
expected: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add some external IPs as negative test cases
|
||||
externalIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"208.67.222.222",
|
||||
}
|
||||
for _, ip := range externalIPs {
|
||||
tests = append(tests, struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
ip: ip,
|
||||
expected: false,
|
||||
})
|
||||
}
|
||||
|
||||
require.NotEmpty(t, tests, "No test cases generated")
|
||||
|
||||
err = manager.UpdateLocalIPs(mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Testing %d IPs", len(tests))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MapImplementation is a version using map[string]struct{}
|
||||
type MapImplementation struct {
|
||||
localIPs map[string]struct{}
|
||||
}
|
||||
|
||||
func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces := make([]net.IP, 16)
|
||||
for i := range interfaces {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap version
|
||||
bitmapManager := &localIPManager{
|
||||
ipv4Bitmap: [1 << 16]uint32{},
|
||||
}
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
|
||||
// Setup map version
|
||||
mapManager := &MapImplementation{
|
||||
localIPs: make(map[string]struct{}),
|
||||
}
|
||||
for _, ip := range interfaces[:8] {
|
||||
mapManager.localIPs[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bitmapManager.checkBitmapBit(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Hit", func(b *testing.B) {
|
||||
ip := interfaces[4]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Map_Miss", func(b *testing.B) {
|
||||
ip := interfaces[12]
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkWGPosition(b *testing.B) {
|
||||
wgIP := net.ParseIP("10.10.0.1")
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
}
|
||||
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bm.checkBitmapBit(wgIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
196
client/firewall/uspfilter/log/log.go
Normal file
196
client/firewall/uspfilter/log/log.go
Normal file
@@ -0,0 +1,196 @@
|
||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||
maxMessageSize = 1024 * 2 // 2KB per message
|
||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
)
|
||||
|
||||
// Level represents log severity
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
LevelPanic Level = iota
|
||||
LevelFatal
|
||||
LevelError
|
||||
LevelWarn
|
||||
LevelInfo
|
||||
LevelDebug
|
||||
LevelTrace
|
||||
)
|
||||
|
||||
var levelStrings = map[Level]string{
|
||||
LevelPanic: "PANC",
|
||||
LevelFatal: "FATL",
|
||||
LevelError: "ERRO",
|
||||
LevelWarn: "WARN",
|
||||
LevelInfo: "INFO",
|
||||
LevelDebug: "DEBG",
|
||||
LevelTrace: "TRAC",
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
type Logger struct {
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
buffer *ringBuffer
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Reusable buffer pool for formatting messages
|
||||
bufPool sync.Pool
|
||||
}
|
||||
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
buffer: newRingBuffer(bufferSize),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// Pre-allocate buffer for message formatting
|
||||
b := make([]byte, 0, maxMessageSize)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
logrusLevel := logrusLogger.GetLevel()
|
||||
l.level.Store(uint32(logrusLevel))
|
||||
level := levelStrings[Level(logrusLevel)]
|
||||
log.Debugf("New uspfilter logger created with loglevel %v", level)
|
||||
|
||||
l.wg.Add(1)
|
||||
go l.worker()
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *Logger) SetLevel(level Level) {
|
||||
l.level.Store(uint32(level))
|
||||
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||
*buf = (*buf)[:0]
|
||||
|
||||
// Timestamp
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Level
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Message
|
||||
if len(args) > 0 {
|
||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||
} else {
|
||||
*buf = append(*buf, format...)
|
||||
}
|
||||
|
||||
*buf = append(*buf, '\n')
|
||||
}
|
||||
|
||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
l.formatMessage(bufp, level, format, args...)
|
||||
|
||||
if len(*bufp) > maxMessageSize {
|
||||
*bufp = (*bufp)[:maxMessageSize]
|
||||
}
|
||||
_, _ = l.buffer.Write(*bufp)
|
||||
|
||||
l.bufPool.Put(bufp)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
l.log(LevelError, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
l.log(LevelWarn, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelInfo) {
|
||||
l.log(LevelInfo, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
l.log(LevelDebug, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
l.log(LevelTrace, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// worker periodically flushes the buffer
|
||||
func (l *Logger) worker() {
|
||||
defer l.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(defaultFlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
buf := make([]byte, 0, maxBatchSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.shutdown:
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Read accumulated messages
|
||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Write batch
|
||||
_, _ = l.output.Write(buf[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the logger
|
||||
func (l *Logger) Stop(ctx context.Context) error {
|
||||
done := make(chan struct{})
|
||||
|
||||
l.closeOnce.Do(func() {
|
||||
close(l.shutdown)
|
||||
})
|
||||
|
||||
go func() {
|
||||
l.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package log
|
||||
|
||||
import "sync"
|
||||
|
||||
// ringBuffer is a simple ring buffer implementation
|
||||
type ringBuffer struct {
|
||||
buf []byte
|
||||
size int
|
||||
r, w int64 // Read and write positions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newRingBuffer(size int) *ringBuffer {
|
||||
return &ringBuffer{
|
||||
buf: make([]byte, size),
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(p) > r.size {
|
||||
p = p[:r.size]
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
|
||||
// Write data, handling wrap-around
|
||||
pos := int(r.w % int64(r.size))
|
||||
writeLen := min(len(p), r.size-pos)
|
||||
copy(r.buf[pos:], p[:writeLen])
|
||||
|
||||
// If we have more data and need to wrap around
|
||||
if writeLen < len(p) {
|
||||
copy(r.buf, p[writeLen:])
|
||||
}
|
||||
|
||||
// Update write position
|
||||
r.w += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.w == r.r {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Calculate available data accounting for wraparound
|
||||
available := int(r.w - r.r)
|
||||
if available < 0 {
|
||||
available += r.size
|
||||
}
|
||||
available = min(available, r.size)
|
||||
|
||||
// Limit read to buffer size
|
||||
toRead := min(available, len(p))
|
||||
if toRead == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read data, handling wrap-around
|
||||
pos := int(r.r % int64(r.size))
|
||||
readLen := min(toRead, r.size-pos)
|
||||
n = copy(p, r.buf[pos:pos+readLen])
|
||||
|
||||
// If we need more data and need to wrap around
|
||||
if readLen < toRead {
|
||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||
}
|
||||
|
||||
// Update read position
|
||||
r.r += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -2,14 +2,15 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
ip net.IP
|
||||
ipLayer gopacket.LayerType
|
||||
@@ -24,6 +25,21 @@ type Rule struct {
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
func (r *PeerRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *RouteRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
390
client/firewall/uspfilter/tracer.go
Normal file
390
client/firewall/uspfilter/tracer.go
Normal file
@@ -0,0 +1,390 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
)
|
||||
|
||||
type PacketStage int
|
||||
|
||||
const (
|
||||
StageReceived PacketStage = iota
|
||||
StageConntrack
|
||||
StagePeerACL
|
||||
StageRouting
|
||||
StageRouteACL
|
||||
StageForwarding
|
||||
StageCompleted
|
||||
)
|
||||
|
||||
const msgProcessingCompleted = "Processing completed"
|
||||
|
||||
func (s PacketStage) String() string {
|
||||
return map[PacketStage]string{
|
||||
StageReceived: "Received",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
}[s]
|
||||
}
|
||||
|
||||
type ForwarderAction struct {
|
||||
Action string
|
||||
RemoteAddr string
|
||||
Error error
|
||||
}
|
||||
|
||||
type TraceResult struct {
|
||||
Timestamp time.Time
|
||||
Stage PacketStage
|
||||
Message string
|
||||
Allowed bool
|
||||
ForwarderAction *ForwarderAction
|
||||
}
|
||||
|
||||
type PacketTrace struct {
|
||||
SourceIP net.IP
|
||||
DestinationIP net.IP
|
||||
Protocol string
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
Direction fw.RuleDirection
|
||||
Results []TraceResult
|
||||
}
|
||||
|
||||
type TCPState struct {
|
||||
SYN bool
|
||||
ACK bool
|
||||
FIN bool
|
||||
RST bool
|
||||
PSH bool
|
||||
URG bool
|
||||
}
|
||||
|
||||
type PacketBuilder struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
Protocol fw.Protocol
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
Direction fw.RuleDirection
|
||||
PayloadSize int
|
||||
TCPState *TCPState
|
||||
}
|
||||
|
||||
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
|
||||
t.Results = append(t.Results, TraceResult{
|
||||
Timestamp: time.Now(),
|
||||
Stage: stage,
|
||||
Message: message,
|
||||
Allowed: allowed,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
|
||||
t.Results = append(t.Results, TraceResult{
|
||||
Timestamp: time.Now(),
|
||||
Stage: stage,
|
||||
Message: message,
|
||||
Allowed: allowed,
|
||||
ForwarderAction: action,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||
ip := p.buildIPLayer()
|
||||
pktLayers := []gopacket.SerializableLayer{ip}
|
||||
|
||||
transportLayer, err := p.buildTransportLayer(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pktLayers = append(pktLayers, transportLayer...)
|
||||
|
||||
if p.PayloadSize > 0 {
|
||||
payload := make([]byte, p.PayloadSize)
|
||||
pktLayers = append(pktLayers, gopacket.Payload(payload))
|
||||
}
|
||||
|
||||
return serializePacket(pktLayers)
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
return &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP,
|
||||
DstIP: p.DstIP,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
switch p.Protocol {
|
||||
case "tcp":
|
||||
return p.buildTCPLayer(ip)
|
||||
case "udp":
|
||||
return p.buildUDPLayer(ip)
|
||||
case "icmp":
|
||||
return p.buildICMPLayer()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(p.SrcPort),
|
||||
DstPort: layers.TCPPort(p.DstPort),
|
||||
Window: 65535,
|
||||
SYN: p.TCPState != nil && p.TCPState.SYN,
|
||||
ACK: p.TCPState != nil && p.TCPState.ACK,
|
||||
FIN: p.TCPState != nil && p.TCPState.FIN,
|
||||
RST: p.TCPState != nil && p.TCPState.RST,
|
||||
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||
URG: p.TCPState != nil && p.TCPState.URG,
|
||||
}
|
||||
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||
}
|
||||
return []gopacket.SerializableLayer{tcp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(p.SrcPort),
|
||||
DstPort: layers.UDPPort(p.DstPort),
|
||||
}
|
||||
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||
}
|
||||
return []gopacket.SerializableLayer{udp}, nil
|
||||
}
|
||||
|
||||
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||
}
|
||||
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
|
||||
icmp.Id = uint16(1)
|
||||
icmp.Seq = uint16(1)
|
||||
}
|
||||
return []gopacket.SerializableLayer{icmp}, nil
|
||||
}
|
||||
|
||||
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
|
||||
return nil, fmt.Errorf("serialize packet: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||
switch protocol {
|
||||
case fw.ProtocolTCP:
|
||||
return int(layers.IPProtocolTCP)
|
||||
case fw.ProtocolUDP:
|
||||
return int(layers.IPProtocolUDP)
|
||||
case fw.ProtocolICMP:
|
||||
return int(layers.IPProtocolICMPv4)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
|
||||
packetData, err := builder.Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build packet: %w", err)
|
||||
}
|
||||
|
||||
return m.TracePacket(packetData, builder.Direction), nil
|
||||
}
|
||||
|
||||
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
trace := &PacketTrace{Direction: direction}
|
||||
|
||||
// Initial packet decoding
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||
return trace
|
||||
}
|
||||
|
||||
// Extract base packet info
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
trace.SourceIP = srcIP
|
||||
trace.DestinationIP = dstIP
|
||||
|
||||
// Determine protocol and ports
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
trace.Protocol = "TCP"
|
||||
trace.SourcePort = uint16(d.tcp.SrcPort)
|
||||
trace.DestinationPort = uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
trace.Protocol = "UDP"
|
||||
trace.SourcePort = uint16(d.udp.SrcPort)
|
||||
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
trace.Protocol = "ICMP"
|
||||
}
|
||||
|
||||
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
return m.traceOutbound(packetData, trace)
|
||||
}
|
||||
|
||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if !m.handleRouting(trace) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.nativeRouter {
|
||||
return m.handleNativeRouter(trace)
|
||||
}
|
||||
|
||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||
msg := "No existing connection found"
|
||||
if allowed {
|
||||
msg = m.buildConntrackStateMessage(d)
|
||||
trace.AddResult(StageConntrack, msg, true)
|
||||
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
|
||||
return true
|
||||
}
|
||||
trace.AddResult(StageConntrack, msg, false)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
msg := "Matched existing connection state"
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
|
||||
flags&conntrack.TCPSyn != 0,
|
||||
flags&conntrack.TCPAck != 0,
|
||||
flags&conntrack.TCPRst != 0,
|
||||
flags&conntrack.TCPFin != 0)
|
||||
case layers.LayerTypeICMPv4:
|
||||
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
if !m.localForwarding {
|
||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
msg := "Allowed by peer ACL rules"
|
||||
if blocked {
|
||||
msg = "Blocked by peer ACL rules"
|
||||
}
|
||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||
|
||||
if m.netstack {
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||
if !m.routingEnabled {
|
||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||
return false
|
||||
}
|
||||
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
|
||||
trace.AddResult(StageForwarding, "Forwarding via native router", true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
msg := "Allowed by route ACLs"
|
||||
if !allowed {
|
||||
msg = "Blocked by route ACLs"
|
||||
}
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
|
||||
fwdAction := &ForwarderAction{
|
||||
Action: action,
|
||||
RemoteAddr: remoteAddr,
|
||||
}
|
||||
trace.AddResultWithForwarder(StageForwarding,
|
||||
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.processOutgoingHooks(packetData)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
|
||||
}
|
||||
return trace
|
||||
}
|
||||
@@ -1,11 +1,14 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -14,28 +17,48 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
|
||||
var (
|
||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]Rule
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
type RouteRules []RouteRule
|
||||
|
||||
func (r RouteRules) Sort() {
|
||||
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
||||
// Deny rules come first
|
||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||
return -1
|
||||
}
|
||||
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.id, b.id)
|
||||
})
|
||||
}
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
@@ -43,17 +66,34 @@ type Manager struct {
|
||||
outgoingRules map[string]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[string]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
stateful bool
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
// indicates whether we forward packets not destined for ourselves
|
||||
routingEnabled bool
|
||||
// indicates whether we leave forwarding and filtering to the native firewall
|
||||
nativeRouter bool
|
||||
// indicates whether we track outbound connections
|
||||
stateful bool
|
||||
// indicates whether wireguards runs in netstack mode
|
||||
netstack bool
|
||||
// indicates whether we forward local traffic to the native stack
|
||||
localForwarding bool
|
||||
|
||||
localipmanager *localIPManager
|
||||
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder *forwarder.Forwarder
|
||||
logger *nblog.Logger
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -70,22 +110,44 @@ type decoder struct {
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return create(iface)
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||
mgr, err := create(iface)
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr.nativeFirewall = nativeFirewall
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||
func parseCreateEnv() (bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||
disableConntrack, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
|
||||
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
@@ -101,52 +163,182 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
return d
|
||||
},
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
stateful: !disableConntrack,
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
}
|
||||
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
}
|
||||
|
||||
// Only initialize trackers if stateful mode is enabled
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
if m.netstack && m.localForwarding {
|
||||
if err := m.initForwarder(); err != nil {
|
||||
log.Errorf("failed to initialize forwarder: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.blockInvalidRouted(iface); err != nil {
|
||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder == nil {
|
||||
return nil
|
||||
}
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
if _, err := m.AddRouteFiltering(
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
wgPrefix,
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionDrop,
|
||||
); err != nil {
|
||||
return fmt.Errorf("block wg nte : %w", err)
|
||||
}
|
||||
|
||||
// TODO: Block networks that we're a client of
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) determineRouting() error {
|
||||
var disableUspRouting, forceUserspaceRouter bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||
disableUspRouting, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
||||
forceUserspaceRouter, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case disableUspRouting:
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
log.Info("userspace routing is disabled")
|
||||
|
||||
case m.disableServerRoutes:
|
||||
// if server routes are disabled we will let packets pass to the native stack
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
|
||||
log.Info("server routes are disabled")
|
||||
|
||||
case forceUserspaceRouter:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
|
||||
log.Info("native routing is enabled")
|
||||
|
||||
default:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
|
||||
log.Info("userspace routing enabled by default")
|
||||
}
|
||||
|
||||
if m.routingEnabled && !m.nativeRouter {
|
||||
return m.initForwarder()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := m.wgIface.GetWGDevice()
|
||||
if intf == nil {
|
||||
m.routingEnabled = false
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||
if err != nil {
|
||||
m.routingEnabled = false
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
}
|
||||
|
||||
m.forwarder = forwarder
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
if m.nativeFirewall == nil {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
}
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
|
||||
// userspace routed packets are always SNATed to the inbound direction
|
||||
// TODO: implement outbound SNAT
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
}
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
@@ -162,7 +354,7 @@ func (m *Manager) AddPeerFiltering(
|
||||
_ string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
r := Rule{
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
@@ -205,18 +397,56 @@ func (m *Manager) AddPeerFiltering(
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
id: ruleID,
|
||||
sources: sources,
|
||||
destination: destination,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules.Sort()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := rule.GetRuleID()
|
||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -224,7 +454,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
r, ok := rule.(*PeerRule)
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
@@ -255,10 +485,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.incomingRules)
|
||||
return m.dropFilter(packetData)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
@@ -279,18 +513,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Always process UDP hooks
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
// Track UDP state only if enabled
|
||||
if m.stateful {
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
// Track other protocols only if stateful mode is enabled
|
||||
// Track all protocols if stateful mode is enabled
|
||||
if m.stateful {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeTCP:
|
||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeICMPv4:
|
||||
@@ -298,6 +525,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Process UDP hooks even if stateful mode is disabled
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -379,10 +611,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
// TODO: Disable router if --disable-server-router is set
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
@@ -395,39 +626,129 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
|
||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check connection state only if enabled
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
return m.applyRules(srcIP, packetData, rules, d)
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
}
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
||||
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// if running in netstack mode we need to pass this to the forwarder
|
||||
if m.netstack {
|
||||
return m.handleNetstackLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
if !m.localForwarding {
|
||||
// pass to virtual tcp/ip stack to be picked up by listeners
|
||||
return false
|
||||
}
|
||||
|
||||
if m.forwarder == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
return true
|
||||
}
|
||||
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject local packet: %v", err)
|
||||
}
|
||||
|
||||
// don't process this packet further
|
||||
return true
|
||||
}
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter {
|
||||
return false
|
||||
}
|
||||
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||
srcIP, srcPort, dstIP, dstPort, proto)
|
||||
return true
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||
}
|
||||
|
||||
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
||||
return true
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return firewall.ProtocolTCP
|
||||
case layers.LayerTypeUDP:
|
||||
return firewall.ProtocolUDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return firewall.ProtocolICMP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
|
||||
default:
|
||||
return 0, 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
log.Tracef("couldn't decode layer, err: %s", err)
|
||||
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
log.Tracef("not enough levels in network packet")
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
||||
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
@@ -462,7 +783,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
|
||||
func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
if d.decoded[1] != layers.LayerTypeICMPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
icmpType := d.icmp4.TypeCode.Type()
|
||||
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
if m.isSpecialICMP(d) {
|
||||
return false
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||
return filter
|
||||
}
|
||||
@@ -496,7 +832,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
@@ -533,6 +869,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
return false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||
return rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
if !rule.destination.Contains(dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
sourceMatched := false
|
||||
for _, src := range rule.sources {
|
||||
if src.Contains(srcAddr) {
|
||||
sourceMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
return false
|
||||
}
|
||||
|
||||
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
@@ -544,7 +925,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
func (m *Manager) AddUDPPacketHook(
|
||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||
) string {
|
||||
r := Rule{
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
@@ -561,12 +942,12 @@ func (m *Manager) AddUDPPacketHook(
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
@@ -599,3 +980,41 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
func (m *Manager) SetLogLevel(level log.Level) {
|
||||
if m.logger != nil {
|
||||
m.logger.SetLevel(nblog.Level(level))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.determineRouting()
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.forwarder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
|
||||
// don't stop forwarder if in use by netstack
|
||||
if m.netstack && m.localForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.forwarder.Stop()
|
||||
m.forwarder = nil
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
//go:build uspbench
|
||||
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn, manager.incomingRules)
|
||||
manager.dropFilter(testIn)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
manager.processOutgoingHooks(syn)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
@@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
@@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response, manager.incomingRules)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
@@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
@@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
@@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response, manager.incomingRules)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
@@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func BenchmarkRouteACLs(b *testing.B) {
|
||||
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||
|
||||
// Add several route rules to simulate real-world scenario
|
||||
rules := []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
proto fw.Protocol
|
||||
port *fw.Port
|
||||
}{
|
||||
{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
port: &fw.Port{Values: []uint16{80, 443}},
|
||||
},
|
||||
{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/12"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: fw.ProtocolICMP,
|
||||
},
|
||||
{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: fw.ProtocolUDP,
|
||||
port: &fw.Port{Values: []uint16{53}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
nil,
|
||||
r.port,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases that exercise different matching scenarios
|
||||
cases := []struct {
|
||||
srcIP string
|
||||
dstIP string
|
||||
proto fw.Protocol
|
||||
dstPort uint16
|
||||
}{
|
||||
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
|
||||
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
|
||||
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
|
||||
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tc := range cases {
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
1015
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,17 +9,38 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||
if i.GetWGDeviceFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return i.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
|
||||
if i.GetDeviceFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return i.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
@@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -95,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -166,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule Rule
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
@@ -215,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -247,9 +268,18 @@ func TestManagerReset(t *testing.T) {
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -298,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes(), m.incomingRules) {
|
||||
if m.dropFilter(buf.Bytes()) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
@@ -317,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface)
|
||||
manager, err := Create(iface, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
@@ -363,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -371,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
@@ -449,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, false)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -476,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
@@ -485,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@@ -606,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@@ -677,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
||||
drop = manager.dropFilter(testBuf.Bytes())
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
}
|
||||
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
if params.Net == nil {
|
||||
var err error
|
||||
if params.Net, err = stdnet.NewNet(); err != nil {
|
||||
params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||
}
|
||||
} else {
|
||||
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &UDPMuxDefault{
|
||||
mux := &UDPMuxDefault{
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
@@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
return newBufferHolder(receiveMTU + maxAddrSize)
|
||||
},
|
||||
},
|
||||
localAddrsForUnspecified: localAddrsForUnspecified,
|
||||
}
|
||||
|
||||
mux.updateLocalAddresses()
|
||||
return mux
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) updateLocalAddresses() {
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
if m.params.Net == nil {
|
||||
var err error
|
||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||
}
|
||||
} else {
|
||||
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.localAddrsForUnspecified = localAddrsForUnspecified
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||
@@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
m.updateLocalAddresses()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if len(m.localAddrsForUnspecified) > 0 {
|
||||
return m.localAddrsForUnspecified
|
||||
return slices.Clone(m.localAddrsForUnspecified)
|
||||
}
|
||||
|
||||
return []net.Addr{m.LocalAddr()}
|
||||
@@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
// creates the connection if an existing one can't be found
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||
// don't check addr for mux using unspecified address
|
||||
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
||||
m.mu.Lock()
|
||||
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
||||
m.mu.Unlock()
|
||||
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
||||
return nil, fmt.Errorf("invalid address %s", addr.String())
|
||||
}
|
||||
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "wt0"
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "utun100"
|
||||
|
||||
@@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.NetbirdFwmark
|
||||
}
|
||||
return 0
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@@ -15,4 +19,6 @@ type WGTunDevice interface {
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -63,7 +64,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
@@ -130,6 +131,10 @@ func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
func routesToString(routes []string) string {
|
||||
return strings.Join(routes, ";")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -117,6 +118,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||
@@ -138,3 +144,7 @@ func (t *TunDevice) assignAddr() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -64,7 +65,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
log.Debug("Attaching to interface")
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
@@ -131,3 +132,7 @@ func (t *TunDevice) UpdateAddr(addr WGAddress) error {
|
||||
func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -151,6 +153,11 @@ func (t *TunKernelDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Device returns the wireguard device, not applicable for kernel devices
|
||||
func (t *TunKernelDevice) Device() *device.Device {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||
return nil
|
||||
}
|
||||
@@ -159,3 +166,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||
func (t *TunKernelDevice) assignAddr() error {
|
||||
return t.link.assignAddr(t.address)
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
@@ -25,9 +27,11 @@ type TunNetstackDevice struct {
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
nsTun *netstack.NetStackTun
|
||||
nsTun *nbnetstack.NetStackTun
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
|
||||
net *netstack.Net
|
||||
}
|
||||
|
||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
@@ -43,13 +47,19 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
log.Info("create netstack tun interface")
|
||||
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
|
||||
tunIface, err := t.nsTun.Create()
|
||||
log.Info("create nbnetstack tun interface")
|
||||
|
||||
// TODO: get from service listener runtime IP
|
||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
log.Debugf("netstack using address: %s", t.address.IP)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||
tunIface, net, err := t.nsTun.Create()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tun device: %s", err)
|
||||
}
|
||||
t.filteredDevice = newDeviceFilter(tunIface)
|
||||
t.net = net
|
||||
|
||||
t.device = device.NewDevice(
|
||||
t.filteredDevice,
|
||||
@@ -117,3 +127,12 @@ func (t *TunNetstackDevice) DeviceName() string {
|
||||
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunNetstackDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) GetNet() *netstack.Net {
|
||||
return t.net
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -124,9 +125,18 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *USPDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *USPDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
|
||||
return link.assignAddr(t.address)
|
||||
}
|
||||
|
||||
func (t *USPDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
@@ -150,6 +151,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||
if t.nativeTunDevice == nil {
|
||||
return "", fmt.Errorf("interface has not been initialized yet")
|
||||
@@ -169,3 +175,7 @@ func (t *TunDevice) assignAddr() error {
|
||||
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
|
||||
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@@ -13,4 +17,6 @@ type WGTunDevice interface {
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -203,6 +206,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
|
||||
return w.tun.FilteredDevice()
|
||||
}
|
||||
|
||||
// GetWGDevice returns the WireGuard device
|
||||
func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||
return w.tun.Device()
|
||||
}
|
||||
|
||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return w.configurer.GetStats(peerKey)
|
||||
@@ -234,3 +242,11 @@ func (w *WGIface) waitUntilRemoved() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetNet returns the netstack.Net for the netstack device
|
||||
func (w *WGIface) GetNet() *netstack.Net {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.tun.GetNet()
|
||||
}
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() device.WGAddress
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
AddAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
CloseFunc func() error
|
||||
SetFilterFunc func(filter device.PacketFilter) error
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return m.GetInterfaceGUIDStringFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Create() error {
|
||||
return m.CreateFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) IsUserspaceBind() bool {
|
||||
return m.IsUserspaceBindFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Name() string {
|
||||
return m.NameFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Address() device.WGAddress {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) ToInterface() *net.Interface {
|
||||
return m.ToInterfaceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
||||
return m.UpdateAddrFunc(newAddr)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemovePeer(peerKey string) error {
|
||||
return m.RemovePeerFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.AddAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
|
||||
return m.SetFilterFunc(filter)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetFilter() device.PacketFilter {
|
||||
return m.GetFilterFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
||||
return m.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type IWGIface interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close() error
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||
|
||||
// IsEnabled todo: move these function to cmd layer
|
||||
func IsEnabled() bool {
|
||||
return os.Getenv("NB_USE_NETSTACK_MODE") == "true"
|
||||
return os.Getenv(EnvUseNetstackMode) == "true"
|
||||
}
|
||||
|
||||
func ListenAddr() string {
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
|
||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||
|
||||
type NetStackTun struct { //nolint:revive
|
||||
address string
|
||||
address net.IP
|
||||
dnsAddress net.IP
|
||||
mtu int
|
||||
listenAddress string
|
||||
|
||||
@@ -17,29 +24,48 @@ type NetStackTun struct { //nolint:revive
|
||||
tundev tun.Device
|
||||
}
|
||||
|
||||
func NewNetStackTun(listenAddress string, address string, mtu int) *NetStackTun {
|
||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
||||
return &NetStackTun{
|
||||
address: address,
|
||||
dnsAddress: dnsAddress,
|
||||
mtu: mtu,
|
||||
listenAddress: listenAddress,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Create() (tun.Device, error) {
|
||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
addr, ok := netip.AddrFromSlice(t.address)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
||||
}
|
||||
|
||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
||||
}
|
||||
|
||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||
[]netip.Addr{netip.MustParseAddr(t.address)},
|
||||
[]netip.Addr{},
|
||||
[]netip.Addr{addr.Unmap()},
|
||||
[]netip.Addr{dnsAddr.Unmap()},
|
||||
t.mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
t.tundev = nsTunDev
|
||||
|
||||
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err)
|
||||
}
|
||||
if skipProxy {
|
||||
return nsTunDev, tunNet, nil
|
||||
}
|
||||
|
||||
dialer := NewNSDialer(tunNet)
|
||||
t.proxy, err = NewSocks5(dialer)
|
||||
if err != nil {
|
||||
_ = t.tundev.Close()
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -49,7 +75,7 @@ func (t *NetStackTun) Create() (tun.Device, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return nsTunDev, nil
|
||||
return nsTunDev, tunNet, nil
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Close() error {
|
||||
|
||||
@@ -268,7 +268,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.PortInfo != nil {
|
||||
if !portInfoEmpty(r.PortInfo) {
|
||||
port = convertPortInfo(r.PortInfo)
|
||||
} else if r.Port != "" {
|
||||
// old version of management, single port
|
||||
@@ -305,6 +305,22 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
if portInfo == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
switch portInfo.GetPortSelection().(type) {
|
||||
case *mgmProto.PortInfo_Port:
|
||||
return portInfo.GetPort() == 0
|
||||
case *mgmProto.PortInfo_Range_:
|
||||
r := portInfo.GetRange()
|
||||
return r == nil || r.Start == 0 || r.End == 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
@@ -491,7 +507,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
|
||||
@@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) {
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
@@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
iface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
||||
}
|
||||
|
||||
// GetDevice mocks base method.
|
||||
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDevice")
|
||||
ret0, _ := ret[0].(*device.FilteredDevice)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetDevice indicates an expected call of GetDevice.
|
||||
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
|
||||
}
|
||||
|
||||
// GetWGDevice mocks base method.
|
||||
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetWGDevice")
|
||||
ret0, _ := ret[0].(*wgdevice.Device)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetWGDevice indicates an expected call of GetWGDevice.
|
||||
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,7 +13,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
@@ -56,6 +61,18 @@ func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*Devi
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
certPool = embeddedroots.Get()
|
||||
} else {
|
||||
log.Debug("Using system certificate pool.")
|
||||
}
|
||||
|
||||
httpTransport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: certPool,
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -68,6 +70,10 @@ type ConfigInput struct {
|
||||
DisableFirewall *bool
|
||||
|
||||
BlockLANAccess *bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
DNSLabels domain.List
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -93,6 +99,10 @@ type Config struct {
|
||||
|
||||
BlockLANAccess bool
|
||||
|
||||
DisableNotifications *bool
|
||||
|
||||
DNSLabels domain.List
|
||||
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
|
||||
@@ -469,6 +479,23 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||
if *input.DisableNotifications {
|
||||
log.Infof("disabling notifications")
|
||||
} else {
|
||||
log.Infof("enabling notifications")
|
||||
}
|
||||
config.DisableNotifications = input.DisableNotifications
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.DisableNotifications == nil {
|
||||
disabled := true
|
||||
config.DisableNotifications = &disabled
|
||||
log.Infof("setting notifications to disabled by default")
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ClientCertKeyPath != "" {
|
||||
config.ClientCertKeyPath = input.ClientCertKeyPath
|
||||
updated = true
|
||||
@@ -489,6 +516,14 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if input.DNSLabels != nil && !slices.Equal(config.DNSLabels, input.DNSLabels) {
|
||||
log.Infof("updating DNS labels [ %s ] (old value: [ %s ])",
|
||||
input.DNSLabels.SafeString(),
|
||||
config.DNSLabels.SafeString())
|
||||
config.DNSLabels = input.DNSLabels
|
||||
updated = true
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
@@ -31,6 +32,7 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -59,13 +61,8 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run() error {
|
||||
return c.run(MobileDependency{}, nil, nil)
|
||||
}
|
||||
|
||||
// RunWithProbes runs the client's main logic with probes attached
|
||||
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, probes, runningChan)
|
||||
func (c *ConnectClient) Run(runningChan chan error) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
@@ -84,7 +81,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
@@ -102,18 +99,30 @@ func (c *ConnectClient) RunOniOS(
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, nil)
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
if rec != nil {
|
||||
rec.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL, cProto.SystemEvent_SYSTEM,
|
||||
"panic occurred",
|
||||
"The Netbird service panicked. Please restart the service and submit a bug report with the client logs.",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||
|
||||
nbnet.Init()
|
||||
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -182,7 +191,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
}
|
||||
}()
|
||||
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||
if err != nil {
|
||||
log.Debug(err)
|
||||
@@ -204,8 +213,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||
|
||||
signalURL := fmt.Sprintf("%s://%s",
|
||||
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
|
||||
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
||||
strings.ToLower(loginResp.GetNetbirdConfig().GetSignal().GetProtocol().String()),
|
||||
loginResp.GetNetbirdConfig().GetSignal().GetUri(),
|
||||
)
|
||||
|
||||
c.statusRecorder.UpdateSignalAddress(signalURL)
|
||||
@@ -216,8 +225,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
c.statusRecorder.MarkSignalDisconnected(err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
||||
// with the global Netbird config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetNetbirdConfig(), myPrivateKey)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -261,7 +270,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@@ -316,7 +325,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
}
|
||||
|
||||
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
||||
relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
|
||||
relayCfg := loginResp.GetNetbirdConfig().GetRelay()
|
||||
if relayCfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -445,7 +454,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
}
|
||||
|
||||
// connectToSignal creates Signal Service client and established a connection
|
||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||
var sigTLSEnabled bool
|
||||
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
|
||||
sigTLSEnabled = true
|
||||
@@ -462,7 +471,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
||||
return signalClient, nil
|
||||
}
|
||||
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
@@ -480,7 +489,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.DisableDNS,
|
||||
config.DisableFirewall,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
111
client/internal/dns.go
Normal file
111
client/internal/dns.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||
ip := net.ParseIP(aRecord.RData)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
if !ipNet.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
ipOctets := strings.Split(ip.String(), ".")
|
||||
slices.Reverse(ipOctets)
|
||||
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
|
||||
|
||||
return nbdns.SimpleRecord{
|
||||
Name: rdnsName,
|
||||
Type: int(dns.TypePTR),
|
||||
Class: aRecord.Class,
|
||||
TTL: aRecord.TTL,
|
||||
RData: dns.Fqdn(aRecord.Name),
|
||||
}, true
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||
maskOnes, _ := ipNet.Mask.Size()
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (maskOnes + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
for i := 0; i < octetsToUse; i++ {
|
||||
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||
}
|
||||
|
||||
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||
}
|
||||
|
||||
// zoneExists checks if a zone with the given name already exists in the configuration
|
||||
func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.Domain == zoneName {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||
records = append(records, ptrRecord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
zoneName, err := generateReverseZoneName(ipNet)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return
|
||||
}
|
||||
|
||||
if zoneExists(config, zoneName) {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return
|
||||
}
|
||||
|
||||
records := collectPTRRecords(config, ipNet)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
}
|
||||
|
||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
|
||||
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
return ErrRouteAllWithoutNameserverGroup
|
||||
}
|
||||
|
||||
if !backupFileExist {
|
||||
@@ -121,6 +121,10 @@ func (f *fileConfigurator) restoreHostDNS() error {
|
||||
return f.restore()
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) string() string {
|
||||
return "file"
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) backup() error {
|
||||
stats, err := os.Stat(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
const (
|
||||
PriorityDNSRoute = 100
|
||||
PriorityMatchDomain = 50
|
||||
PriorityDefault = 0
|
||||
PriorityDefault = 1
|
||||
)
|
||||
|
||||
type SubdomainMatcher interface {
|
||||
@@ -26,7 +26,6 @@ type HandlerEntry struct {
|
||||
Pattern string
|
||||
OrigPattern string
|
||||
IsWildcard bool
|
||||
StopHandler handlerWithStop
|
||||
MatchSubdomains bool
|
||||
}
|
||||
|
||||
@@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
|
||||
}
|
||||
|
||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
if c.handlers[i].StopHandler != nil {
|
||||
c.handlers[i].StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
@@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
Pattern: pattern,
|
||||
OrigPattern: origPattern,
|
||||
IsWildcard: isWildcard,
|
||||
StopHandler: stopHandler,
|
||||
MatchSubdomains: matchSubdomains,
|
||||
}
|
||||
|
||||
@@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
if entry.StopHandler != nil {
|
||||
entry.StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
return
|
||||
}
|
||||
@@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("current handlers (%d):", len(handlers))
|
||||
for _, h := range handlers {
|
||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
if !matched {
|
||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
|
||||
@@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
dnsRouteHandler := &nbdns.MockHandler{}
|
||||
|
||||
// Setup handlers with different priorities
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
@@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
pattern = "*." + tt.handlerDomain[2:]
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
@@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||
}
|
||||
|
||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
|
||||
}
|
||||
|
||||
// Create and execute request
|
||||
@@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
handler3 := &nbdns.MockHandler{}
|
||||
|
||||
// Add handlers in priority order
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
@@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
if op.action == "add" {
|
||||
handler := &nbdns.MockHandler{}
|
||||
handlers[op.priority] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||
chain.AddHandler(op.pattern, handler, op.priority)
|
||||
} else {
|
||||
chain.RemoveHandler(op.pattern, op.priority)
|
||||
}
|
||||
@@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
r.SetQuestion(testQuery, dns.TypeA)
|
||||
|
||||
// Add handlers in mixed order
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state with all three handlers
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
@@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
handler = mockHandler
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
||||
chain.AddHandler(pattern, handler, h.priority)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
@@ -795,7 +795,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
if op.action == "add" {
|
||||
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
||||
handlers[op.pattern] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||
chain.AddHandler(op.pattern, handler, op.priority)
|
||||
} else {
|
||||
chain.RemoveHandler(op.pattern, op.priority)
|
||||
}
|
||||
|
||||
@@ -9,10 +9,18 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa"
|
||||
ipv6ReverseZone = ".ip6.arpa"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNS() error
|
||||
supportCustomPort() bool
|
||||
string() string
|
||||
}
|
||||
|
||||
type SystemDNSSettings struct {
|
||||
@@ -39,6 +47,7 @@ type mockHostConfigurator struct {
|
||||
restoreHostDNSFunc func() error
|
||||
supportCustomPortFunc func() bool
|
||||
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
||||
stringFunc func() string
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
@@ -62,6 +71,13 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) string() string {
|
||||
if m.stringFunc != nil {
|
||||
return m.stringFunc()
|
||||
}
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func newNoopHostMocker() hostManager {
|
||||
return &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
|
||||
@@ -94,9 +110,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
MatchOnly: false,
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -116,3 +133,7 @@ func (n noopHostConfigurator) restoreHostDNS() error {
|
||||
func (n noopHostConfigurator) supportCustomPort() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (n noopHostConfigurator) string() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
@@ -22,3 +22,7 @@ func (a androidHostManager) restoreHostDNS() error {
|
||||
func (a androidHostManager) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a androidHostManager) string() string {
|
||||
return "none"
|
||||
}
|
||||
|
||||
@@ -114,6 +114,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) string() string {
|
||||
return "scutil"
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreHostDNS() error {
|
||||
keys := s.getRemovableKeysWithDefaults()
|
||||
for _, key := range keys {
|
||||
|
||||
@@ -38,3 +38,7 @@ func (a iosHostManager) restoreHostDNS() error {
|
||||
func (a iosHostManager) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a iosHostManager) string() string {
|
||||
return "none"
|
||||
}
|
||||
|
||||
@@ -1,35 +1,51 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
var (
|
||||
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||
)
|
||||
|
||||
const (
|
||||
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient`
|
||||
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\DnsPolicyConfig\NetBird-Match`
|
||||
|
||||
dnsPolicyConfigVersionKey = "Version"
|
||||
dnsPolicyConfigVersionValue = 2
|
||||
dnsPolicyConfigNameKey = "Name"
|
||||
dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"
|
||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||
)
|
||||
|
||||
const (
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
|
||||
// RP_FORCE: Reapply all policies even if no policy change was detected
|
||||
rpForce = 0x1
|
||||
)
|
||||
|
||||
type registryConfigurator struct {
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
@@ -37,12 +53,20 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newHostManagerWithGuid(guid)
|
||||
}
|
||||
|
||||
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
|
||||
var useGPO bool
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, gpoDnsPolicyRoot, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open GPO DNS policy root: %v", err)
|
||||
} else {
|
||||
closer(k)
|
||||
useGPO = true
|
||||
log.Infof("detected GPO DNS policy configuration, using policy store")
|
||||
}
|
||||
|
||||
return ®istryConfigurator{
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -51,30 +75,23 @@ func (r *registryConfigurator) supportCustomPort() bool {
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
var err error
|
||||
if config.RouteAll {
|
||||
err = r.addDNSSetupForAll(config.ServerIP)
|
||||
if err != nil {
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
} else if r.routingAll {
|
||||
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
|
||||
if err != nil {
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
|
||||
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
var searchDomains, matchDomains []string
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
@@ -86,16 +103,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil {
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
return fmt.Errorf("remove dns match policies: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.updateSearchDomains(searchDomains)
|
||||
if err != nil {
|
||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
}
|
||||
|
||||
@@ -103,9 +120,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding dns setup for all failed with error: %w", err)
|
||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil {
|
||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||
}
|
||||
r.routingAll = true
|
||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
||||
@@ -113,64 +129,70 @@ func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
|
||||
_, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE)
|
||||
if err == nil {
|
||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
||||
}
|
||||
|
||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||
return fmt.Errorf("configure local DNS policy: %w", err)
|
||||
}
|
||||
|
||||
if err := refreshGroupPolicy(); err != nil {
|
||||
log.Warnf("failed to refresh group policy: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||
return fmt.Errorf("configure local DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
||||
log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
|
||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip string) error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
|
||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||
}
|
||||
|
||||
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
|
||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
|
||||
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||
}
|
||||
defer closer(regKey)
|
||||
|
||||
if err := regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue); err != nil {
|
||||
return fmt.Errorf("set %s: %w", dnsPolicyConfigVersionKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
|
||||
if err := regKey.SetStringsValue(dnsPolicyConfigNameKey, domains); err != nil {
|
||||
return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
||||
if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip); err != nil {
|
||||
return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
|
||||
if err := regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue); err != nil {
|
||||
return fmt.Errorf("set %s: %w", dnsPolicyConfigConfigOptionsKey, err)
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) restoreHostDNS() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
log.Errorf("remove registry key from dns policy config: %s", err)
|
||||
}
|
||||
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
func (r *registryConfigurator) string() string {
|
||||
return "registry"
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding search domain failed with error: %w", err)
|
||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
|
||||
|
||||
log.Infof("updated search domains: %s", domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -181,11 +203,9 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
|
||||
}
|
||||
defer closer(regKey)
|
||||
|
||||
err = regKey.SetStringValue(key, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
|
||||
if err := regKey.SetStringValue(key, value); err != nil {
|
||||
return fmt.Errorf("set key %s=%s: %w", key, value, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -196,43 +216,91 @@ func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey st
|
||||
}
|
||||
defer closer(regKey)
|
||||
|
||||
err = regKey.DeleteValue(propertyKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
|
||||
if err := regKey.DeleteValue(propertyKey); err != nil {
|
||||
return fmt.Errorf("delete registry key %s: %w", propertyKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
||||
var regKey registry.Key
|
||||
|
||||
regKeyPath := interfaceConfigPath + "\\" + r.guid
|
||||
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
||||
return regKey, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||
}
|
||||
|
||||
return regKey, nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
||||
if err := r.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via registry: %w", err)
|
||||
func (r *registryConfigurator) restoreHostDNS() error {
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
log.Errorf("remove dns match policies: %s", err)
|
||||
}
|
||||
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
var merr *multierror.Error
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err))
|
||||
}
|
||||
|
||||
if err := refreshGroupPolicy(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("refresh group policy: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
||||
return r.restoreHostDNS()
|
||||
}
|
||||
|
||||
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
||||
if err == nil {
|
||||
defer closer(k)
|
||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
||||
}
|
||||
if err != nil {
|
||||
log.Debugf("failed to open HKEY_LOCAL_MACHINE\\%s: %v", regKeyPath, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
closer(k)
|
||||
if err := registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath); err != nil {
|
||||
return fmt.Errorf("delete HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func refreshGroupPolicy() error {
|
||||
// refreshPolicyExFn.Call() panics if the func is not found
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("Recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
ret, _, err := refreshPolicyExFn.Call(
|
||||
// bMachine = TRUE (computer policy)
|
||||
uintptr(1),
|
||||
// dwOptions = RP_FORCE
|
||||
uintptr(rpForce),
|
||||
)
|
||||
|
||||
if ret == 0 {
|
||||
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||
return fmt.Errorf("RefreshPolicyEx failed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("RefreshPolicyEx failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -14,7 +15,7 @@ type registrationMap map[string]struct{}
|
||||
|
||||
type localResolver struct {
|
||||
registeredMap registrationMap
|
||||
records sync.Map
|
||||
records sync.Map // key: string (domain_class_type), value: []dns.RR
|
||||
}
|
||||
|
||||
func (d *localResolver) MatchSubdomains() bool {
|
||||
@@ -29,20 +30,26 @@ func (d *localResolver) String() string {
|
||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *localResolver) id() handlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) > 0 {
|
||||
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
}
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
|
||||
response := d.lookupRecord(r)
|
||||
if response != nil {
|
||||
replyMessage.Answer = append(replyMessage.Answer, response)
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(r)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
}
|
||||
@@ -53,37 +60,65 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
||||
if len(r.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
question := r.Question[0]
|
||||
record, found := d.records.Load(buildRecordKey(question.Name, question.Qclass, question.Qtype))
|
||||
question.Name = strings.ToLower(question.Name)
|
||||
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
|
||||
|
||||
value, found := d.records.Load(key)
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
return record.(dns.RR)
|
||||
}
|
||||
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
fullRecord, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("register record: %w", err)
|
||||
records, ok := value.([]dns.RR)
|
||||
if !ok {
|
||||
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
|
||||
return nil
|
||||
}
|
||||
|
||||
fullRecord.Header().Rdlength = record.Len()
|
||||
// if there's more than one record, rotate them (round-robin)
|
||||
if len(records) > 1 {
|
||||
first := records[0]
|
||||
records = append(records[1:], first)
|
||||
d.records.Store(key, records)
|
||||
}
|
||||
|
||||
header := fullRecord.Header()
|
||||
d.records.Store(buildRecordKey(header.Name, header.Class, header.Rrtype), fullRecord)
|
||||
|
||||
return nil
|
||||
return records
|
||||
}
|
||||
|
||||
// registerRecord stores a new record by appending it to any existing list
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
|
||||
rr, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("register record: %w", err)
|
||||
}
|
||||
|
||||
rr.Header().Rdlength = record.Len()
|
||||
header := rr.Header()
|
||||
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
|
||||
|
||||
// load any existing slice of records, then append
|
||||
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
|
||||
records := existing.([]dns.RR)
|
||||
records = append(records, rr)
|
||||
|
||||
// store updated slice
|
||||
d.records.Store(key, records)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// deleteRecord removes *all* records under the recordKey.
|
||||
func (d *localResolver) deleteRecord(recordKey string) {
|
||||
d.records.Delete(dns.Fqdn(recordKey))
|
||||
}
|
||||
|
||||
// buildRecordKey consistently generates a key: name_class_type
|
||||
func buildRecordKey(name string, class, qType uint16) string {
|
||||
key := fmt.Sprintf("%s_%d_%d", name, class, qType)
|
||||
return key
|
||||
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
|
||||
}
|
||||
|
||||
func (d *localResolver) probeAvailability() {}
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
resolver := &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
}
|
||||
_ = resolver.registerRecord(testCase.inputRecord)
|
||||
_, _ = resolver.registerRecord(testCase.inputRecord)
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
|
||||
@@ -179,6 +179,10 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) string() string {
|
||||
return "network-manager"
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||
if err != nil {
|
||||
|
||||
@@ -91,7 +91,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
|
||||
if err != nil {
|
||||
log.Errorf("restore host dns: %s", err)
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
|
||||
return ErrRouteAllWithoutNameserverGroup
|
||||
}
|
||||
|
||||
searchDomainList := searchDomains(config)
|
||||
@@ -139,6 +139,10 @@ func (r *resolvconf) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) string() string {
|
||||
return fmt.Sprintf("resolvconf (%s)", r.implType)
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@@ -42,7 +43,12 @@ type Server interface {
|
||||
ProbeAvailability()
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
type handlerID string
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
domain string
|
||||
groups []*nbdns.NameServerGroup
|
||||
}
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
@@ -52,7 +58,6 @@ type DefaultServer struct {
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
handlerPriorities map[string]int
|
||||
localResolver *localResolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
@@ -77,14 +82,17 @@ type handlerWithStop interface {
|
||||
dns.Handler
|
||||
stop()
|
||||
probeAvailability()
|
||||
id() handlerID
|
||||
}
|
||||
|
||||
type muxUpdate struct {
|
||||
type handlerWrapper struct {
|
||||
domain string
|
||||
handler handlerWithStop
|
||||
priority int
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[handlerID]handlerWrapper
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
@@ -158,13 +166,12 @@ func newDefaultServer(
|
||||
) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
handlerPriorities: make(map[string]int),
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
@@ -192,8 +199,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
||||
log.Warn("skipping empty domain")
|
||||
continue
|
||||
}
|
||||
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
||||
s.handlerPriorities[domain] = priority
|
||||
s.handlerChain.AddHandler(domain, handler, priority)
|
||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||
}
|
||||
}
|
||||
@@ -209,14 +215,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||
|
||||
for _, domain := range domains {
|
||||
if domain == "" {
|
||||
log.Warn("skipping empty domain")
|
||||
continue
|
||||
}
|
||||
|
||||
s.handlerChain.RemoveHandler(domain, priority)
|
||||
|
||||
// Only deregister from service if no handlers remain
|
||||
if !s.handlerChain.HasHandlers(domain) {
|
||||
if domain == "" {
|
||||
log.Warn("skipping empty domain")
|
||||
continue
|
||||
}
|
||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||
}
|
||||
}
|
||||
@@ -283,14 +290,24 @@ func (s *DefaultServer) Stop() {
|
||||
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||
|
||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
s.hostsDNSHolder.set(hostsDnsList)
|
||||
|
||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||
if ok {
|
||||
// Check if there's any root handler
|
||||
var hasRootHandler bool
|
||||
for _, handler := range s.dnsMuxMap {
|
||||
if handler.domain == nbdns.RootZone {
|
||||
hasRootHandler = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasRootHandler {
|
||||
log.Debugf("on new host DNS config but skip to apply it")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||
s.addHostRootZone()
|
||||
}
|
||||
@@ -364,7 +381,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
||||
go func(mux handlerWithStop) {
|
||||
defer wg.Done()
|
||||
mux.probeAvailability()
|
||||
}(mux)
|
||||
}(mux.handler)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -378,18 +395,22 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
return fmt.Errorf("local handler updater: %w", err)
|
||||
}
|
||||
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
return fmt.Errorf("upstream handler updater: %w", err)
|
||||
}
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
|
||||
// register local records
|
||||
s.updateLocalResolver(localRecordsByDomain)
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
hostUpdate := s.currentConfig
|
||||
@@ -401,6 +422,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
||||
log.Error(err)
|
||||
s.handleErrNoGroupaAll(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -419,42 +441,111 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
||||
return
|
||||
}
|
||||
|
||||
if s.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_WARNING, cProto.SystemEvent_DNS,
|
||||
"The host dns manager does not support match domains",
|
||||
"The host dns manager does not support match domains without a catch-all nameserver group.",
|
||||
map[string]string{"manager": s.hostManager.string()},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
customZones []nbdns.CustomZone,
|
||||
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
localRecords := make(map[string][]nbdns.SimpleRecord)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain)
|
||||
continue
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
|
||||
// group all records under this domain
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||
log.Warnf("received an invalid class type: %s", record.Class)
|
||||
continue
|
||||
}
|
||||
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
localRecords[key] = record
|
||||
|
||||
localRecords[key] = append(localRecords[key], record)
|
||||
}
|
||||
}
|
||||
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
log.Warn("received a nameserver group with empty nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
if !nsGroup.Primary && len(nsGroup.Domains) == 0 {
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
||||
|
||||
for _, domainGroup := range groupedNS {
|
||||
basePriority := PriorityMatchDomain
|
||||
if domainGroup.domain == nbdns.RootZone {
|
||||
basePriority = PriorityDefault
|
||||
}
|
||||
|
||||
updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
muxUpdates = append(muxUpdates, updates...)
|
||||
}
|
||||
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
|
||||
for i, nsGroup := range domainGroup.groups {
|
||||
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
||||
priority := basePriority - i
|
||||
|
||||
// Check if we're about to overlap with the next priority tier
|
||||
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
|
||||
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
||||
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
@@ -462,10 +553,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
domainGroup.domain,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
||||
}
|
||||
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
||||
@@ -489,81 +582,51 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
// after some period defined by upstream it tries to reactivate self by calling this hook
|
||||
// everything we need here is just to re-apply current configuration because it already
|
||||
// contains this upstream settings (temporal deactivation not removed it)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority)
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
priority: PriorityDefault,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(nsGroup.Domains) == 0 {
|
||||
handler.stop()
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
handler.stop()
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
}
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
domain: domainGroup.domain,
|
||||
handler: handler,
|
||||
priority: priority,
|
||||
})
|
||||
}
|
||||
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
handlersByPriority := make(map[string]int)
|
||||
|
||||
var isContainRootUpdate bool
|
||||
|
||||
// First register new handlers
|
||||
for _, update := range muxUpdates {
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.domain] = update.handler
|
||||
handlersByPriority[update.domain] = update.priority
|
||||
|
||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||
existingHandler.stop()
|
||||
}
|
||||
|
||||
if update.domain == nbdns.RootZone {
|
||||
isContainRootUpdate = true
|
||||
}
|
||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
existing.handler.stop()
|
||||
}
|
||||
|
||||
// Then deregister old handlers not in the update
|
||||
for key, existingHandler := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
var containsRootUpdate bool
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
if update.domain == nbdns.RootZone {
|
||||
containsRootUpdate = true
|
||||
}
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.id()] = update
|
||||
}
|
||||
|
||||
// If there's no root update and we had a root handler, restore it
|
||||
if !containsRootUpdate {
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
if existing.domain == nbdns.RootZone {
|
||||
s.addHostRootZone()
|
||||
existingHandler.stop()
|
||||
} else {
|
||||
existingHandler.stop()
|
||||
// Deregister with the priority that was used to register
|
||||
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
||||
s.deregisterHandler([]string{key}, oldPriority)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
s.handlerPriorities = handlersByPriority
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
|
||||
// remove old records that are no longer present
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
@@ -572,12 +635,18 @@ func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for key, record := range update {
|
||||
err := s.localResolver.registerRecord(record)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||
for _, recs := range update {
|
||||
for _, rec := range recs {
|
||||
// convert the record to a dns.RR and register
|
||||
key, err := s.localResolver.registerRecord(rec)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v",
|
||||
rec.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
@@ -593,6 +662,7 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
priority int,
|
||||
) (deactivate func(error), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func(err error) {
|
||||
@@ -609,18 +679,19 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.RouteAll = false
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, priority)
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.Domains {
|
||||
if _, found := removeIndex[item.Domain]; found {
|
||||
s.currentConfig.Domains[i].Disabled = true
|
||||
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
||||
s.deregisterHandler([]string{item.Domain}, priority)
|
||||
removeIndex[item.Domain] = i
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
s.handleErrNoGroupaAll(err)
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
@@ -635,8 +706,8 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, err, false)
|
||||
|
||||
}
|
||||
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -646,7 +717,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
continue
|
||||
}
|
||||
s.currentConfig.Domains[i].Disabled = false
|
||||
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
||||
s.registerHandler([]string{domain}, handler, priority)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
@@ -654,11 +725,12 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.RouteAll = true
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
if s.hostManager != nil {
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
s.handleErrNoGroupaAll(err)
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
@@ -676,6 +748,7 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
nbdns.RootZone,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
@@ -732,5 +805,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
||||
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
||||
}
|
||||
|
||||
// groupNSGroupsByDomain groups nameserver groups by their match domains
|
||||
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
|
||||
domainMap := make(map[string][]*nbdns.NameServerGroup)
|
||||
|
||||
for _, group := range nsGroups {
|
||||
if group.Primary {
|
||||
domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, domain := range group.Domains {
|
||||
if domain == "" {
|
||||
continue
|
||||
}
|
||||
domainMap[domain] = append(domainMap[domain], group)
|
||||
}
|
||||
}
|
||||
|
||||
var result []nsGroupsByDomain
|
||||
for domain, groups := range domainMap {
|
||||
result = append(result, nsGroupsByDomain{
|
||||
domain: domain,
|
||||
groups: groups,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
@@ -88,6 +89,18 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []string
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, getNSHostPort(srv))
|
||||
}
|
||||
return &upstreamResolverBase{
|
||||
domain: domain,
|
||||
upstreamServers: srvs,
|
||||
cancel: func() {},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
@@ -140,15 +153,37 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
dummyHandler.id(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "New Config Should Succeed",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -164,8 +199,19 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
@@ -220,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Fail",
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
@@ -239,12 +285,22 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
@@ -252,9 +308,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalMap: make(registrationMap),
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
@@ -294,7 +356,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -351,7 +413,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
@@ -403,7 +465,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@@ -421,7 +483,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &localResolver{},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
@@ -498,7 +566,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
@@ -509,7 +577,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer dnsServer.Stop()
|
||||
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -562,9 +630,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
handlerChain: NewHandlerChain(),
|
||||
handlerPriorities: make(map[string]int),
|
||||
hostManager: hostManager,
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
@@ -572,7 +639,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
{false, "domain2", false},
|
||||
},
|
||||
},
|
||||
statusRecorder: &peer.Status{},
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
@@ -593,7 +660,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
}, nil)
|
||||
}, nil, 0)
|
||||
|
||||
deactivate(nil)
|
||||
expected := "domain0,domain2"
|
||||
@@ -633,7 +700,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
|
||||
var dnsList []string
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -657,7 +724,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -749,7 +816,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -820,7 +887,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Fatalf("create stdnet: %v", err)
|
||||
return nil, err
|
||||
@@ -849,7 +916,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface)
|
||||
pf, err := uspfilter.Create(wgIface, false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
@@ -903,8 +970,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
Subdomains: true,
|
||||
}
|
||||
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -959,3 +1026,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockHandler struct {
|
||||
Id string
|
||||
}
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) stop() {}
|
||||
func (m *mockHandler) probeAvailability() {}
|
||||
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
|
||||
func (m *mockService) Listen() error { return nil }
|
||||
func (m *mockService) Stop() {}
|
||||
func (m *mockService) RuntimeIP() string { return "127.0.0.1" }
|
||||
func (m *mockService) RuntimePort() int { return 53 }
|
||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
func (m *mockService) DeregisterMux(string) {}
|
||||
|
||||
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
baseMatchHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
}
|
||||
|
||||
baseRootHandlers := registeredHandlerMap{
|
||||
"upstream-root1": {
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
"upstream-root2": {
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
},
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
}
|
||||
|
||||
baseMixedHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
"upstream-other": {
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialHandlers registeredHandlerMap
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[handlerID]domain
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Remove group1 from update",
|
||||
initialHandlers: baseMatchHandlers,
|
||||
updates: []handlerWrapper{
|
||||
// Only group2 remains
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group2": "example.com",
|
||||
},
|
||||
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
||||
},
|
||||
{
|
||||
name: "Remove group2 from update",
|
||||
initialHandlers: baseMatchHandlers,
|
||||
updates: []handlerWrapper{
|
||||
// Only group1 remains
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
},
|
||||
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
||||
},
|
||||
{
|
||||
name: "Add group3 in first position",
|
||||
initialHandlers: baseMatchHandlers,
|
||||
updates: []handlerWrapper{
|
||||
// Add group3 with highest priority
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group3",
|
||||
},
|
||||
priority: PriorityMatchDomain + 1,
|
||||
},
|
||||
// Keep existing groups with their original priorities
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-group3": "example.com",
|
||||
},
|
||||
description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups",
|
||||
},
|
||||
{
|
||||
name: "Add group3 in last position",
|
||||
initialHandlers: baseMatchHandlers,
|
||||
updates: []handlerWrapper{
|
||||
// Keep existing groups with their original priorities
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
// Add group3 with lowest priority
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group3",
|
||||
},
|
||||
priority: PriorityMatchDomain - 2,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-group3": "example.com",
|
||||
},
|
||||
description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups",
|
||||
},
|
||||
// Root zone tests
|
||||
{
|
||||
name: "Remove root1 from update",
|
||||
initialHandlers: baseRootHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
},
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-root2": ".",
|
||||
},
|
||||
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
||||
},
|
||||
{
|
||||
name: "Remove root2 from update",
|
||||
initialHandlers: baseRootHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-root1": ".",
|
||||
},
|
||||
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
||||
},
|
||||
{
|
||||
name: "Add root3 in first position",
|
||||
initialHandlers: baseRootHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root3",
|
||||
},
|
||||
priority: PriorityDefault + 1,
|
||||
},
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
},
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-root1": ".",
|
||||
"upstream-root2": ".",
|
||||
"upstream-root3": ".",
|
||||
},
|
||||
description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers",
|
||||
},
|
||||
{
|
||||
name: "Add root3 in last position",
|
||||
initialHandlers: baseRootHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
},
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root3",
|
||||
},
|
||||
priority: PriorityDefault - 2,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-root1": ".",
|
||||
"upstream-root2": ".",
|
||||
"upstream-root3": ".",
|
||||
},
|
||||
description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers",
|
||||
},
|
||||
// Mixed domain tests
|
||||
{
|
||||
name: "Update with mixed domains - remove one of duplicate domain",
|
||||
initialHandlers: baseMixedHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-other": "other.com",
|
||||
},
|
||||
description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains",
|
||||
},
|
||||
{
|
||||
name: "Update with mixed domains - add new domain",
|
||||
initialHandlers: baseMixedHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
{
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "new.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-new",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-other": "other.com",
|
||||
"upstream-new": "new.com",
|
||||
},
|
||||
description: "When updating mixed domains, should maintain existing duplicates and add new domain",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
dnsMuxMap: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Perform the update
|
||||
server.updateMux(tt.updates)
|
||||
|
||||
// Verify the results
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||
"Number of handlers after update doesn't match expected")
|
||||
|
||||
// Check each expected handler
|
||||
for id, expectedDomain := range tt.expectedHandlers {
|
||||
handler, exists := server.dnsMuxMap[handlerID(id)]
|
||||
assert.True(t, exists, "Expected handler %s not found", id)
|
||||
if exists {
|
||||
assert.Equal(t, expectedDomain, handler.domain,
|
||||
"Domain mismatch for handler %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected handlers exist
|
||||
for handlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(handlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
||||
}
|
||||
|
||||
// Verify the handlerChain state and order
|
||||
previousPriority := 0
|
||||
for _, chainEntry := range server.handlerChain.handlers {
|
||||
// Verify priority order
|
||||
if previousPriority > 0 {
|
||||
assert.True(t, chainEntry.Priority <= previousPriority,
|
||||
"Handlers in chain not properly ordered by priority")
|
||||
}
|
||||
previousPriority = chainEntry.Priority
|
||||
|
||||
// Verify handler exists in mux
|
||||
foundInMux := false
|
||||
for _, muxEntry := range server.dnsMuxMap {
|
||||
if chainEntry.Handler == muxEntry.handler &&
|
||||
chainEntry.Priority == muxEntry.priority &&
|
||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||
foundInMux = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundInMux,
|
||||
"Handler in chain not found in dnsMuxMap")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user