mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
150 Commits
feature/ex
...
feature/op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
515ce9e3af | ||
|
|
89383b7f01 | ||
|
|
db34162733 | ||
|
|
bd761e2177 | ||
|
|
4e1b95a4c6 | ||
|
|
05993af7bf | ||
|
|
9d1cb00570 | ||
|
|
543731df45 | ||
|
|
e6628ec231 | ||
|
|
41d4dd2aff | ||
|
|
30bed57711 | ||
|
|
6960b68322 | ||
|
|
3b3aa18148 | ||
|
|
93045f3e3a | ||
|
|
fd3c1dea8e | ||
|
|
48aff7a26e | ||
|
|
83dfe8e3a3 | ||
|
|
38e10af2d9 | ||
|
|
99854a126a | ||
|
|
a75f982fcd | ||
|
|
e7a6483912 | ||
|
|
30ede299b8 | ||
|
|
e3b76448f3 | ||
|
|
e0de86d6c9 | ||
|
|
5204d07811 | ||
|
|
5ea24ba56e | ||
|
|
d30cf8706a | ||
|
|
15a2feb723 | ||
|
|
91b2f9fc51 | ||
|
|
76702c8a09 | ||
|
|
061f673a4f | ||
|
|
9505805313 | ||
|
|
704c67dec8 | ||
|
|
3ed2f08f3c | ||
|
|
4c83408f27 | ||
|
|
90bd39c740 | ||
|
|
dd0cf41147 | ||
|
|
22b2caffc6 | ||
|
|
c1f66d1354 | ||
|
|
ac0fe6025b | ||
|
|
c28657710a | ||
|
|
3875c29f6b | ||
|
|
9f32ccd453 | ||
|
|
1d1d057e7d | ||
|
|
3461b1bb90 | ||
|
|
3d2a2377c6 | ||
|
|
25f5f26527 | ||
|
|
bb0d5c5baf | ||
|
|
7938295190 | ||
|
|
9af532fe71 | ||
|
|
23a1473797 | ||
|
|
9c2dc05df1 | ||
|
|
40d56e5d29 | ||
|
|
fd23d0c28f | ||
|
|
4fff93a1f2 | ||
|
|
22beac1b1b | ||
|
|
bd7a65d798 | ||
|
|
2d76b058fc | ||
|
|
ea2d060f93 | ||
|
|
68b377a28c | ||
|
|
af50eb350f | ||
|
|
2475473227 | ||
|
|
846871913d | ||
|
|
6cba9c0818 | ||
|
|
f0672b87bc | ||
|
|
9b0fe2c8e5 | ||
|
|
abd57d1191 | ||
|
|
416f04c27a | ||
|
|
fc7c1e397f | ||
|
|
52a3ac6b06 | ||
|
|
0b3b50c705 | ||
|
|
042141db06 | ||
|
|
4a1aee1ae0 | ||
|
|
ba33572ec9 | ||
|
|
9d213e0b54 | ||
|
|
5dde044fa5 | ||
|
|
5a3d9e401f | ||
|
|
fde1a2196c | ||
|
|
0aeb87742a | ||
|
|
6d747b2f83 | ||
|
|
199bf73103 | ||
|
|
17f5abc653 | ||
|
|
aa935bdae3 | ||
|
|
452419c4c3 | ||
|
|
17b1099032 | ||
|
|
a4b9e93217 | ||
|
|
63d7957140 | ||
|
|
9a6814deff | ||
|
|
190698bcf2 | ||
|
|
468fa2940b | ||
|
|
79a0647a26 | ||
|
|
17ceb3bde8 | ||
|
|
5a8f1763a6 | ||
|
|
f64e73ca70 | ||
|
|
b085419ab8 | ||
|
|
d78b652ff7 | ||
|
|
7251150c1c | ||
|
|
b65c2f69b0 | ||
|
|
d8ce08d898 | ||
|
|
e1c50248d9 | ||
|
|
ce2d14c08e | ||
|
|
52fd9a575a | ||
|
|
9028c3c1f7 | ||
|
|
9357a587e9 | ||
|
|
a47c69c472 | ||
|
|
bbea4c3cc3 | ||
|
|
b7a6cbfaa5 | ||
|
|
e18bf565a2 | ||
|
|
51fa3c92c5 | ||
|
|
d65602f904 | ||
|
|
8d9e1fed5f | ||
|
|
e1eddd1cab | ||
|
|
0fbf72434e | ||
|
|
51f133fdc6 | ||
|
|
d5338c09dc | ||
|
|
8fd4166c53 | ||
|
|
9bc7b9e897 | ||
|
|
db3cba5e0f | ||
|
|
cb3408a10b | ||
|
|
0afd738509 | ||
|
|
cf87f1e702 | ||
|
|
e890fdae54 | ||
|
|
dd14db6478 | ||
|
|
88747e3e01 | ||
|
|
fb30931365 | ||
|
|
a7547b9990 | ||
|
|
62bacee8dc | ||
|
|
71cd2e3e03 | ||
|
|
bdf71ab7ff | ||
|
|
a2f2a6e21a | ||
|
|
f89332fcd2 | ||
|
|
8604add997 | ||
|
|
93cab49696 | ||
|
|
b6835d9467 | ||
|
|
846d486366 | ||
|
|
9c56f74235 | ||
|
|
25b3641be8 | ||
|
|
c41504b571 | ||
|
|
399493a954 | ||
|
|
4771fed64f | ||
|
|
88117f7d16 | ||
|
|
5ac9f9fe2f | ||
|
|
a7d6632298 | ||
|
|
d4194cba6a | ||
|
|
131d9f1bc7 | ||
|
|
f099e02b34 | ||
|
|
93646e6a13 | ||
|
|
67a2127fd7 | ||
|
|
dd7fcbd083 | ||
|
|
d5f330b9c0 |
18
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
18
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -2,15 +2,17 @@
|
|||||||
name: Bug/Issue report
|
name: Bug/Issue report
|
||||||
about: Create a report to help us improve
|
about: Create a report to help us improve
|
||||||
title: ''
|
title: ''
|
||||||
labels: ''
|
labels: ['triage-needed']
|
||||||
assignees: ''
|
assignees: ''
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**Describe the problem**
|
**Describe the problem**
|
||||||
|
|
||||||
A clear and concise description of what the problem is.
|
A clear and concise description of what the problem is.
|
||||||
|
|
||||||
**To Reproduce**
|
**To Reproduce**
|
||||||
|
|
||||||
Steps to reproduce the behavior:
|
Steps to reproduce the behavior:
|
||||||
1. Go to '...'
|
1. Go to '...'
|
||||||
2. Click on '....'
|
2. Click on '....'
|
||||||
@@ -18,13 +20,25 @@ Steps to reproduce the behavior:
|
|||||||
4. See error
|
4. See error
|
||||||
|
|
||||||
**Expected behavior**
|
**Expected behavior**
|
||||||
|
|
||||||
A clear and concise description of what you expected to happen.
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Are you using NetBird Cloud?**
|
||||||
|
|
||||||
|
Please specify whether you use NetBird Cloud or self-host NetBird's control plane.
|
||||||
|
|
||||||
|
**NetBird version**
|
||||||
|
|
||||||
|
`netbird version`
|
||||||
|
|
||||||
**NetBird status -d output:**
|
**NetBird status -d output:**
|
||||||
If applicable, add the output of the `netbird status -d` command
|
|
||||||
|
If applicable, add the `netbird status -d' command output.
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
|
|
||||||
If applicable, add screenshots to help explain your problem.
|
If applicable, add screenshots to help explain your problem.
|
||||||
|
|
||||||
**Additional context**
|
**Additional context**
|
||||||
|
|
||||||
Add any other context about the problem here.
|
Add any other context about the problem here.
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -2,7 +2,7 @@
|
|||||||
name: Feature request
|
name: Feature request
|
||||||
about: Suggest an idea for this project
|
about: Suggest an idea for this project
|
||||||
title: ''
|
title: ''
|
||||||
labels: ''
|
labels: ['feature-request']
|
||||||
assignees: ''
|
assignees: ''
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
8
.github/workflows/golang-test-darwin.yml
vendored
8
.github/workflows/golang-test-darwin.yml
vendored
@@ -32,8 +32,14 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
macos-go-
|
macos-go-
|
||||||
|
|
||||||
|
- name: Install libpcap
|
||||||
|
run: brew install libpcap
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||||
|
|||||||
24
.github/workflows/golang-test-linux.yml
vendored
24
.github/workflows/golang-test-linux.yml
vendored
@@ -14,8 +14,8 @@ jobs:
|
|||||||
test:
|
test:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
arch: ['386','amd64']
|
arch: [ '386','amd64' ]
|
||||||
store: ['jsonfile', 'sqlite']
|
store: [ 'jsonfile', 'sqlite' ]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
@@ -36,13 +36,20 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
@@ -64,11 +71,14 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Generate Iface Test bin
|
- name: Generate Iface Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
||||||
|
|
||||||
@@ -76,7 +86,7 @@ jobs:
|
|||||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||||
|
|
||||||
- name: Generate RouteManager Test bin
|
- name: Generate RouteManager Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
|
||||||
|
|
||||||
- name: Generate nftables Manager Test bin
|
- name: Generate nftables Manager Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||||
@@ -103,7 +113,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run Engine tests in docker with file store
|
- name: Run Engine tests in docker with file store
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Engine tests in docker with sqlite store
|
- name: Run Engine tests in docker with sqlite store
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
|
||||||
- name: test output
|
- name: test output
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
run: Get-Content test-out.txt
|
run: Get-Content test-out.txt
|
||||||
|
|||||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -33,6 +33,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
- name: Check for duplicate constants
|
||||||
|
if: matrix.os == 'ubuntu-latest'
|
||||||
|
run: |
|
||||||
|
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
@@ -40,7 +44,7 @@ jobs:
|
|||||||
cache: false
|
cache: false
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v3
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
name: Android build validation
|
name: Mobile build validation
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -11,7 +11,7 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
android_build:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -41,8 +41,25 @@ jobs:
|
|||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build android nebtird lib
|
- name: build android netbird lib
|
||||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||||
env:
|
env:
|
||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||||
|
ios_build:
|
||||||
|
runs-on: macos-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: "1.21.x"
|
||||||
|
- name: install gomobile
|
||||||
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||||
|
- name: gomobile init
|
||||||
|
run: gomobile init
|
||||||
|
- name: build iOS netbird lib
|
||||||
|
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
|
||||||
|
env:
|
||||||
|
CGO_ENABLED: 0
|
||||||
3
.github/workflows/release.yml
vendored
3
.github/workflows/release.yml
vendored
@@ -190,6 +190,9 @@ jobs:
|
|||||||
-
|
-
|
||||||
name: Install modules
|
name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
-
|
||||||
|
name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
-
|
-
|
||||||
name: Run GoReleaser
|
name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
|
|||||||
23
.github/workflows/test-infrastructure-files.yml
vendored
23
.github/workflows/test-infrastructure-files.yml
vendored
@@ -127,6 +127,9 @@ jobs:
|
|||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Build management binary
|
- name: Build management binary
|
||||||
working-directory: management
|
working-directory: management
|
||||||
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
|
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
|
||||||
@@ -159,6 +162,13 @@ jobs:
|
|||||||
test $count -eq 4
|
test $count -eq 4
|
||||||
working-directory: infrastructure_files/artifacts
|
working-directory: infrastructure_files/artifacts
|
||||||
|
|
||||||
|
- name: test geolocation databases
|
||||||
|
working-directory: infrastructure_files/artifacts
|
||||||
|
run: |
|
||||||
|
sleep 30
|
||||||
|
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
|
||||||
|
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
|
||||||
|
|
||||||
test-getting-started-script:
|
test-getting-started-script:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
@@ -186,3 +196,16 @@ jobs:
|
|||||||
run: test -f zitadel.env
|
run: test -f zitadel.env
|
||||||
- name: test dashboard.env file gen
|
- name: test dashboard.env file gen
|
||||||
run: test -f dashboard.env
|
run: test -f dashboard.env
|
||||||
|
test-download-geolite2-script:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Install jq
|
||||||
|
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: test script
|
||||||
|
run: bash -x infrastructure_files/download-geolite2.sh
|
||||||
|
- name: test mmdb file exists
|
||||||
|
run: test -f GeoLite2-City.mmdb
|
||||||
|
- name: test geonames file exists
|
||||||
|
run: test -f geonames.db
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -29,4 +29,4 @@ infrastructure_files/setup.env
|
|||||||
infrastructure_files/setup-*.env
|
infrastructure_files/setup-*.env
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
*.db
|
GeoLite2-City*
|
||||||
@@ -63,6 +63,14 @@ linters-settings:
|
|||||||
enable:
|
enable:
|
||||||
- nilness
|
- nilness
|
||||||
|
|
||||||
|
revive:
|
||||||
|
rules:
|
||||||
|
- name: exported
|
||||||
|
severity: warning
|
||||||
|
disabled: false
|
||||||
|
arguments:
|
||||||
|
- "checkPrivateReceivers"
|
||||||
|
- "sayRepetitiveInsteadOfStutters"
|
||||||
tenv:
|
tenv:
|
||||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||||
@@ -93,6 +101,7 @@ linters:
|
|||||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||||
|
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||||
- wastedassign # wastedassign finds wasted assignment statements
|
- wastedassign # wastedassign finds wasted assignment statements
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ nfpms:
|
|||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-default.png
|
- src: client/ui/netbird-systemtray-connected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
@@ -71,7 +71,7 @@ nfpms:
|
|||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-default.png
|
- src: client/ui/netbird-systemtray-connected.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
|
|||||||
@@ -274,6 +274,8 @@ go test -exec sudo ./...
|
|||||||
```
|
```
|
||||||
> On Windows use a powershell with administrator privileges
|
> On Windows use a powershell with administrator privileges
|
||||||
|
|
||||||
|
> Non-GTK environments will need the `libayatana-appindicator3-dev` (debian/ubuntu) package installed
|
||||||
|
|
||||||
## Checklist before submitting a PR
|
## Checklist before submitting a PR
|
||||||
As a critical network service and open-source project, we must enforce a few things before submitting the pull-requests:
|
As a critical network service and open-source project, we must enforce a few things before submitting the pull-requests:
|
||||||
- Keep functions as simple as possible, with a single purpose
|
- Keep functions as simple as possible, with a single purpose
|
||||||
|
|||||||
44
README.md
44
README.md
@@ -1,6 +1,6 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
|
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
|
||||||
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
|
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
|
||||||
Learn more
|
Learn more
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
@@ -40,27 +40,25 @@
|
|||||||
|
|
||||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||||
|
|
||||||
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||||
|
|
||||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
### Open-Source Network Security in a Single Platform
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
| Connectivity | Management | Automation | Platforms |
|
| Connectivity | Management | Security | Automation | Platforms |
|
||||||
|---------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
|
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||||
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
|
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||||
| <ul><li> - \[x] Post-quantum-secure connection through [Rosenpass](https://rosenpass.eu) </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||||
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
|
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||||
| | <ul><li> - \[x] SSH access management </ul></li> | | |
|
|
||||||
|
|
||||||
|
|
||||||
### Quickstart with NetBird Cloud
|
### Quickstart with NetBird Cloud
|
||||||
|
|
||||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||||
@@ -79,7 +77,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
|||||||
- **Public domain** name pointing to the VM.
|
- **Public domain** name pointing to the VM.
|
||||||
|
|
||||||
**Software requirements:**
|
**Software requirements:**
|
||||||
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||||
- [curl](https://curl.se/) installed.
|
- [curl](https://curl.se/) installed.
|
||||||
@@ -96,9 +94,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
|||||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||||
- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||||
|
|
||||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||||
|
|
||||||
@@ -109,8 +107,8 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
|||||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||||
|
|
||||||
### Community projects
|
### Community projects
|
||||||
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
|
||||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||||
|
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||||
|
|
||||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||||
@@ -122,7 +120,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
|
|||||||

|

|
||||||
|
|
||||||
### Testimonials
|
### Testimonials
|
||||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||||
|
|
||||||
### Legal
|
### Legal
|
||||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||||
|
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||||
|
|
||||||
var ctx context.Context
|
var ctx context.Context
|
||||||
//nolint
|
//nolint
|
||||||
@@ -109,6 +110,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||||
|
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||||
|
|
||||||
var ctx context.Context
|
var ctx context.Context
|
||||||
//nolint
|
//nolint
|
||||||
@@ -139,6 +141,11 @@ func (c *Client) SetTraceLogLevel() {
|
|||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetInfoLogLevel configure the logger to info level
|
||||||
|
func (c *Client) SetInfoLogLevel() {
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
}
|
||||||
|
|
||||||
// PeersList return with the list of the PeerInfos
|
// PeersList return with the list of the PeerInfos
|
||||||
func (c *Client) PeersList() *PeerInfoArray {
|
func (c *Client) PeersList() *PeerInfoArray {
|
||||||
|
|
||||||
|
|||||||
@@ -82,12 +82,15 @@ var loginCmd = &cobra.Command{
|
|||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: setupKey,
|
SetupKey: setupKey,
|
||||||
PreSharedKey: preSharedKey,
|
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
|
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
|
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
|
|||||||
@@ -25,12 +25,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
externalIPMapFlag = "external-ip-map"
|
externalIPMapFlag = "external-ip-map"
|
||||||
dnsResolverAddress = "dns-resolver-address"
|
dnsResolverAddress = "dns-resolver-address"
|
||||||
enableRosenpassFlag = "enable-rosenpass"
|
enableRosenpassFlag = "enable-rosenpass"
|
||||||
preSharedKeyFlag = "preshared-key"
|
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||||
interfaceNameFlag = "interface-name"
|
preSharedKeyFlag = "preshared-key"
|
||||||
wireguardPortFlag = "wireguard-port"
|
interfaceNameFlag = "interface-name"
|
||||||
|
wireguardPortFlag = "wireguard-port"
|
||||||
|
disableAutoConnectFlag = "disable-auto-connect"
|
||||||
|
serverSSHAllowedFlag = "allow-server-ssh"
|
||||||
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -54,8 +58,13 @@ var (
|
|||||||
natExternalIPs []string
|
natExternalIPs []string
|
||||||
customDNSAddress string
|
customDNSAddress string
|
||||||
rosenpassEnabled bool
|
rosenpassEnabled bool
|
||||||
|
rosenpassPermissive bool
|
||||||
|
serverSSHAllowed bool
|
||||||
interfaceName string
|
interfaceName string
|
||||||
wireguardPort uint16
|
wireguardPort uint16
|
||||||
|
serviceName string
|
||||||
|
autoConnectDisabled bool
|
||||||
|
extraIFaceBlackList []string
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
Short: "",
|
Short: "",
|
||||||
@@ -94,9 +103,16 @@ func init() {
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defaultServiceName := "netbird"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
defaultServiceName = "Netbird"
|
||||||
|
}
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
|
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
|
||||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||||
@@ -126,6 +142,9 @@ func init() {
|
|||||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
||||||
)
|
)
|
||||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||||
@@ -176,7 +195,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
|||||||
return prefix + upper
|
return prefix + upper
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialClientGRPCServer returns client connection to the dameno server.
|
// DialClientGRPCServer returns client connection to the daemon server.
|
||||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -24,12 +22,8 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newSVCConfig() *service.Config {
|
func newSVCConfig() *service.Config {
|
||||||
name := "netbird"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
name = "Netbird"
|
|
||||||
}
|
|
||||||
return &service.Config{
|
return &service.Config{
|
||||||
Name: name,
|
Name: serviceName,
|
||||||
DisplayName: "Netbird",
|
DisplayName: "Netbird",
|
||||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||||
Option: make(service.KeyValue),
|
Option: make(service.KeyValue),
|
||||||
|
|||||||
@@ -11,11 +11,12 @@ import (
|
|||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/spf13/cobra"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *program) Start(svc service.Service) error {
|
func (p *program) Start(svc service.Service) error {
|
||||||
@@ -109,7 +110,6 @@ var runCmd = &cobra.Command{
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cmd.Printf("Netbird service is running")
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,6 +64,10 @@ var installCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
svcConfig.Option["OnFailure"] = "restart"
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
||||||
@@ -77,6 +81,7 @@ var installCmd = &cobra.Command{
|
|||||||
cmd.PrintErrln(err)
|
cmd.PrintErrln(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println("Netbird service has been installed")
|
cmd.Println("Netbird service has been installed")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -106,7 +111,7 @@ var uninstallCmd = &cobra.Command{
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cmd.Println("Netbird has been uninstalled")
|
cmd.Println("Netbird service has been uninstalled")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,14 +22,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type peerStateDetailOutput struct {
|
type peerStateDetailOutput struct {
|
||||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||||
Status string `json:"status" yaml:"status"`
|
Status string `json:"status" yaml:"status"`
|
||||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||||
Direct bool `json:"direct" yaml:"direct"`
|
Direct bool `json:"direct" yaml:"direct"`
|
||||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||||
|
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
||||||
|
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||||
|
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||||
|
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||||
|
Latency time.Duration `json:"latency" yaml:"latency"`
|
||||||
|
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||||
|
Routes []string `json:"routes" yaml:"routes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type peersStateOutput struct {
|
type peersStateOutput struct {
|
||||||
@@ -41,11 +48,25 @@ type peersStateOutput struct {
|
|||||||
type signalStateOutput struct {
|
type signalStateOutput struct {
|
||||||
URL string `json:"url" yaml:"url"`
|
URL string `json:"url" yaml:"url"`
|
||||||
Connected bool `json:"connected" yaml:"connected"`
|
Connected bool `json:"connected" yaml:"connected"`
|
||||||
|
Error string `json:"error" yaml:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type managementStateOutput struct {
|
type managementStateOutput struct {
|
||||||
URL string `json:"url" yaml:"url"`
|
URL string `json:"url" yaml:"url"`
|
||||||
Connected bool `json:"connected" yaml:"connected"`
|
Connected bool `json:"connected" yaml:"connected"`
|
||||||
|
Error string `json:"error" yaml:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type relayStateOutputDetail struct {
|
||||||
|
URI string `json:"uri" yaml:"uri"`
|
||||||
|
Available bool `json:"available" yaml:"available"`
|
||||||
|
Error string `json:"error" yaml:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type relayStateOutput struct {
|
||||||
|
Total int `json:"total" yaml:"total"`
|
||||||
|
Available int `json:"available" yaml:"available"`
|
||||||
|
Details []relayStateOutputDetail `json:"details" yaml:"details"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type iceCandidateType struct {
|
type iceCandidateType struct {
|
||||||
@@ -53,16 +74,28 @@ type iceCandidateType struct {
|
|||||||
Remote string `json:"remote" yaml:"remote"`
|
Remote string `json:"remote" yaml:"remote"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type nsServerGroupStateOutput struct {
|
||||||
|
Servers []string `json:"servers" yaml:"servers"`
|
||||||
|
Domains []string `json:"domains" yaml:"domains"`
|
||||||
|
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||||
|
Error string `json:"error" yaml:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
type statusOutputOverview struct {
|
type statusOutputOverview struct {
|
||||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
||||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||||
|
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||||
|
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||||
|
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||||
|
Routes []string `json:"routes" yaml:"routes"`
|
||||||
|
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -146,7 +179,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
case yamlFlag:
|
case yamlFlag:
|
||||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||||
default:
|
default:
|
||||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false)
|
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -220,37 +253,89 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
|
|||||||
managementOverview := managementStateOutput{
|
managementOverview := managementStateOutput{
|
||||||
URL: managementState.GetURL(),
|
URL: managementState.GetURL(),
|
||||||
Connected: managementState.GetConnected(),
|
Connected: managementState.GetConnected(),
|
||||||
|
Error: managementState.Error,
|
||||||
}
|
}
|
||||||
|
|
||||||
signalState := pbFullStatus.GetSignalState()
|
signalState := pbFullStatus.GetSignalState()
|
||||||
signalOverview := signalStateOutput{
|
signalOverview := signalStateOutput{
|
||||||
URL: signalState.GetURL(),
|
URL: signalState.GetURL(),
|
||||||
Connected: signalState.GetConnected(),
|
Connected: signalState.GetConnected(),
|
||||||
|
Error: signalState.Error,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
||||||
|
|
||||||
overview := statusOutputOverview{
|
overview := statusOutputOverview{
|
||||||
Peers: peersOverview,
|
Peers: peersOverview,
|
||||||
CliVersion: version.NetbirdVersion(),
|
CliVersion: version.NetbirdVersion(),
|
||||||
DaemonVersion: resp.GetDaemonVersion(),
|
DaemonVersion: resp.GetDaemonVersion(),
|
||||||
ManagementState: managementOverview,
|
ManagementState: managementOverview,
|
||||||
SignalState: signalOverview,
|
SignalState: signalOverview,
|
||||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
Relays: relayOverview,
|
||||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||||
|
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||||
|
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||||
|
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||||
|
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
|
||||||
|
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||||
}
|
}
|
||||||
|
|
||||||
return overview
|
return overview
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
||||||
|
var relayStateDetail []relayStateOutputDetail
|
||||||
|
|
||||||
|
var relaysAvailable int
|
||||||
|
for _, relay := range relays {
|
||||||
|
available := relay.GetAvailable()
|
||||||
|
relayStateDetail = append(relayStateDetail,
|
||||||
|
relayStateOutputDetail{
|
||||||
|
URI: relay.URI,
|
||||||
|
Available: available,
|
||||||
|
Error: relay.GetError(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if available {
|
||||||
|
relaysAvailable++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return relayStateOutput{
|
||||||
|
Total: len(relays),
|
||||||
|
Available: relaysAvailable,
|
||||||
|
Details: relayStateDetail,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||||
|
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
||||||
|
for _, pbNsGroupServer := range servers {
|
||||||
|
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
||||||
|
Servers: pbNsGroupServer.GetServers(),
|
||||||
|
Domains: pbNsGroupServer.GetDomains(),
|
||||||
|
Enabled: pbNsGroupServer.GetEnabled(),
|
||||||
|
Error: pbNsGroupServer.GetError(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return mappedNSGroups
|
||||||
|
}
|
||||||
|
|
||||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||||
var peersStateDetail []peerStateDetailOutput
|
var peersStateDetail []peerStateDetailOutput
|
||||||
localICE := ""
|
localICE := ""
|
||||||
remoteICE := ""
|
remoteICE := ""
|
||||||
|
localICEEndpoint := ""
|
||||||
|
remoteICEEndpoint := ""
|
||||||
connType := ""
|
connType := ""
|
||||||
peersConnected := 0
|
peersConnected := 0
|
||||||
|
lastHandshake := time.Time{}
|
||||||
|
transferReceived := int64(0)
|
||||||
|
transferSent := int64(0)
|
||||||
for _, pbPeerState := range peers {
|
for _, pbPeerState := range peers {
|
||||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||||
@@ -261,10 +346,15 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
|||||||
|
|
||||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||||
|
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||||
|
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||||
connType = "P2P"
|
connType = "P2P"
|
||||||
if pbPeerState.Relayed {
|
if pbPeerState.Relayed {
|
||||||
connType = "Relayed"
|
connType = "Relayed"
|
||||||
}
|
}
|
||||||
|
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||||
|
transferReceived = pbPeerState.GetBytesRx()
|
||||||
|
transferSent = pbPeerState.GetBytesTx()
|
||||||
}
|
}
|
||||||
|
|
||||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||||
@@ -279,7 +369,17 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
|||||||
Local: localICE,
|
Local: localICE,
|
||||||
Remote: remoteICE,
|
Remote: remoteICE,
|
||||||
},
|
},
|
||||||
FQDN: pbPeerState.GetFqdn(),
|
IceCandidateEndpoint: iceCandidateType{
|
||||||
|
Local: localICEEndpoint,
|
||||||
|
Remote: remoteICEEndpoint,
|
||||||
|
},
|
||||||
|
FQDN: pbPeerState.GetFqdn(),
|
||||||
|
LastWireguardHandshake: lastHandshake,
|
||||||
|
TransferReceived: transferReceived,
|
||||||
|
TransferSent: transferSent,
|
||||||
|
Latency: pbPeerState.GetLatency().AsDuration(),
|
||||||
|
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||||
|
Routes: pbPeerState.GetRoutes(),
|
||||||
}
|
}
|
||||||
|
|
||||||
peersStateDetail = append(peersStateDetail, peerState)
|
peersStateDetail = append(peersStateDetail, peerState)
|
||||||
@@ -329,22 +429,31 @@ func parseToYAML(overview statusOutputOverview) (string, error) {
|
|||||||
return string(yamlBytes), nil
|
return string(yamlBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||||
|
var managementConnString string
|
||||||
managementConnString := "Disconnected"
|
|
||||||
if overview.ManagementState.Connected {
|
if overview.ManagementState.Connected {
|
||||||
managementConnString = "Connected"
|
managementConnString = "Connected"
|
||||||
if showURL {
|
if showURL {
|
||||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
managementConnString = "Disconnected"
|
||||||
|
if overview.ManagementState.Error != "" {
|
||||||
|
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
signalConnString := "Disconnected"
|
var signalConnString string
|
||||||
if overview.SignalState.Connected {
|
if overview.SignalState.Connected {
|
||||||
signalConnString = "Connected"
|
signalConnString = "Connected"
|
||||||
if showURL {
|
if showURL {
|
||||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
signalConnString = "Disconnected"
|
||||||
|
if overview.SignalState.Error != "" {
|
||||||
|
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
interfaceTypeString := "Userspace"
|
interfaceTypeString := "Userspace"
|
||||||
@@ -356,6 +465,64 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
|||||||
interfaceIP = "N/A"
|
interfaceIP = "N/A"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var relaysString string
|
||||||
|
if showRelays {
|
||||||
|
for _, relay := range overview.Relays.Details {
|
||||||
|
available := "Available"
|
||||||
|
reason := ""
|
||||||
|
if !relay.Available {
|
||||||
|
available = "Unavailable"
|
||||||
|
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||||
|
}
|
||||||
|
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := "-"
|
||||||
|
if len(overview.Routes) > 0 {
|
||||||
|
sort.Strings(overview.Routes)
|
||||||
|
routes = strings.Join(overview.Routes, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
var dnsServersString string
|
||||||
|
if showNameServers {
|
||||||
|
for _, nsServerGroup := range overview.NSServerGroups {
|
||||||
|
enabled := "Available"
|
||||||
|
if !nsServerGroup.Enabled {
|
||||||
|
enabled = "Unavailable"
|
||||||
|
}
|
||||||
|
errorString := ""
|
||||||
|
if nsServerGroup.Error != "" {
|
||||||
|
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
||||||
|
errorString = strings.TrimSpace(errorString)
|
||||||
|
}
|
||||||
|
|
||||||
|
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
||||||
|
if domainsString == "" {
|
||||||
|
domainsString = "." // Show "." for the default zone
|
||||||
|
}
|
||||||
|
dnsServersString += fmt.Sprintf(
|
||||||
|
"\n [%s] for [%s] is %s%s",
|
||||||
|
strings.Join(nsServerGroup.Servers, ", "),
|
||||||
|
domainsString,
|
||||||
|
enabled,
|
||||||
|
errorString,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||||
|
}
|
||||||
|
|
||||||
|
rosenpassEnabledStatus := "false"
|
||||||
|
if overview.RosenpassEnabled {
|
||||||
|
rosenpassEnabledStatus = "true"
|
||||||
|
if overview.RosenpassPermissive {
|
||||||
|
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||||
|
|
||||||
summary := fmt.Sprintf(
|
summary := fmt.Sprintf(
|
||||||
@@ -363,25 +530,33 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
|||||||
"CLI version: %s\n"+
|
"CLI version: %s\n"+
|
||||||
"Management: %s\n"+
|
"Management: %s\n"+
|
||||||
"Signal: %s\n"+
|
"Signal: %s\n"+
|
||||||
|
"Relays: %s\n"+
|
||||||
|
"Nameservers: %s\n"+
|
||||||
"FQDN: %s\n"+
|
"FQDN: %s\n"+
|
||||||
"NetBird IP: %s\n"+
|
"NetBird IP: %s\n"+
|
||||||
"Interface type: %s\n"+
|
"Interface type: %s\n"+
|
||||||
|
"Quantum resistance: %s\n"+
|
||||||
|
"Routes: %s\n"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
overview.DaemonVersion,
|
overview.DaemonVersion,
|
||||||
version.NetbirdVersion(),
|
version.NetbirdVersion(),
|
||||||
managementConnString,
|
managementConnString,
|
||||||
signalConnString,
|
signalConnString,
|
||||||
|
relaysString,
|
||||||
|
dnsServersString,
|
||||||
overview.FQDN,
|
overview.FQDN,
|
||||||
interfaceIP,
|
interfaceIP,
|
||||||
interfaceTypeString,
|
interfaceTypeString,
|
||||||
|
rosenpassEnabledStatus,
|
||||||
|
routes,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
)
|
)
|
||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||||
parsedPeersString := parsePeers(overview.Peers)
|
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||||
summary := parseGeneralSummary(overview, true)
|
summary := parseGeneralSummary(overview, true, true, true)
|
||||||
|
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"Peers detail:"+
|
"Peers detail:"+
|
||||||
@@ -392,7 +567,7 @@ func parseToFullDetailSummary(overview statusOutputOverview) string {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parsePeers(peers peersStateOutput) string {
|
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
||||||
var (
|
var (
|
||||||
peersString = ""
|
peersString = ""
|
||||||
)
|
)
|
||||||
@@ -409,6 +584,48 @@ func parsePeers(peers peersStateOutput) string {
|
|||||||
remoteICE = peerState.IceCandidateType.Remote
|
remoteICE = peerState.IceCandidateType.Remote
|
||||||
}
|
}
|
||||||
|
|
||||||
|
localICEEndpoint := "-"
|
||||||
|
if peerState.IceCandidateEndpoint.Local != "" {
|
||||||
|
localICEEndpoint = peerState.IceCandidateEndpoint.Local
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteICEEndpoint := "-"
|
||||||
|
if peerState.IceCandidateEndpoint.Remote != "" {
|
||||||
|
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
||||||
|
}
|
||||||
|
lastStatusUpdate := "-"
|
||||||
|
if !peerState.LastStatusUpdate.IsZero() {
|
||||||
|
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
|
||||||
|
}
|
||||||
|
|
||||||
|
lastWireGuardHandshake := "-"
|
||||||
|
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
|
||||||
|
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
|
||||||
|
}
|
||||||
|
|
||||||
|
rosenpassEnabledStatus := "false"
|
||||||
|
if rosenpassEnabled {
|
||||||
|
if peerState.RosenpassEnabled {
|
||||||
|
rosenpassEnabledStatus = "true"
|
||||||
|
} else {
|
||||||
|
if rosenpassPermissive {
|
||||||
|
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
|
||||||
|
} else {
|
||||||
|
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if peerState.RosenpassEnabled {
|
||||||
|
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := "-"
|
||||||
|
if len(peerState.Routes) > 0 {
|
||||||
|
sort.Strings(peerState.Routes)
|
||||||
|
routes = strings.Join(peerState.Routes, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
peerString := fmt.Sprintf(
|
peerString := fmt.Sprintf(
|
||||||
"\n %s:\n"+
|
"\n %s:\n"+
|
||||||
" NetBird IP: %s\n"+
|
" NetBird IP: %s\n"+
|
||||||
@@ -418,7 +635,13 @@ func parsePeers(peers peersStateOutput) string {
|
|||||||
" Connection type: %s\n"+
|
" Connection type: %s\n"+
|
||||||
" Direct: %t\n"+
|
" Direct: %t\n"+
|
||||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||||
" Last connection update: %s\n",
|
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
||||||
|
" Last connection update: %s\n"+
|
||||||
|
" Last WireGuard handshake: %s\n"+
|
||||||
|
" Transfer status (received/sent) %s/%s\n"+
|
||||||
|
" Quantum resistance: %s\n"+
|
||||||
|
" Routes: %s\n"+
|
||||||
|
" Latency: %s\n",
|
||||||
peerState.FQDN,
|
peerState.FQDN,
|
||||||
peerState.IP,
|
peerState.IP,
|
||||||
peerState.PubKey,
|
peerState.PubKey,
|
||||||
@@ -427,7 +650,15 @@ func parsePeers(peers peersStateOutput) string {
|
|||||||
peerState.Direct,
|
peerState.Direct,
|
||||||
localICE,
|
localICE,
|
||||||
remoteICE,
|
remoteICE,
|
||||||
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
|
localICEEndpoint,
|
||||||
|
remoteICEEndpoint,
|
||||||
|
lastStatusUpdate,
|
||||||
|
lastWireGuardHandshake,
|
||||||
|
toIEC(peerState.TransferReceived),
|
||||||
|
toIEC(peerState.TransferSent),
|
||||||
|
rosenpassEnabledStatus,
|
||||||
|
routes,
|
||||||
|
peerState.Latency.String(),
|
||||||
)
|
)
|
||||||
|
|
||||||
peersString += peerString
|
peersString += peerString
|
||||||
@@ -467,3 +698,27 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
|||||||
|
|
||||||
return statusEval || ipEval || nameEval
|
return statusEval || ipEval || nameEval
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toIEC(b int64) string {
|
||||||
|
const unit = 1024
|
||||||
|
if b < unit {
|
||||||
|
return fmt.Sprintf("%d B", b)
|
||||||
|
}
|
||||||
|
div, exp := int64(unit), 0
|
||||||
|
for n := b / unit; n >= unit; n /= unit {
|
||||||
|
div *= unit
|
||||||
|
exp++
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.1f %ciB",
|
||||||
|
float64(b)/float64(div), "KMGTPE"[exp])
|
||||||
|
}
|
||||||
|
|
||||||
|
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
||||||
|
count := 0
|
||||||
|
for _, server := range dnsServers {
|
||||||
|
if server.Enabled {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
@@ -25,41 +29,95 @@ var resp = &proto.StatusResponse{
|
|||||||
FullStatus: &proto.FullStatus{
|
FullStatus: &proto.FullStatus{
|
||||||
Peers: []*proto.PeerState{
|
Peers: []*proto.PeerState{
|
||||||
{
|
{
|
||||||
IP: "192.168.178.101",
|
IP: "192.168.178.101",
|
||||||
PubKey: "Pubkey1",
|
PubKey: "Pubkey1",
|
||||||
Fqdn: "peer-1.awesome-domain.com",
|
Fqdn: "peer-1.awesome-domain.com",
|
||||||
ConnStatus: "Connected",
|
ConnStatus: "Connected",
|
||||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||||
Relayed: false,
|
Relayed: false,
|
||||||
Direct: true,
|
Direct: true,
|
||||||
LocalIceCandidateType: "",
|
LocalIceCandidateType: "",
|
||||||
RemoteIceCandidateType: "",
|
RemoteIceCandidateType: "",
|
||||||
|
LocalIceCandidateEndpoint: "",
|
||||||
|
RemoteIceCandidateEndpoint: "",
|
||||||
|
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||||
|
BytesRx: 200,
|
||||||
|
BytesTx: 100,
|
||||||
|
Routes: []string{
|
||||||
|
"10.1.0.0/24",
|
||||||
|
},
|
||||||
|
Latency: durationpb.New(time.Duration(10000000)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
IP: "192.168.178.102",
|
IP: "192.168.178.102",
|
||||||
PubKey: "Pubkey2",
|
PubKey: "Pubkey2",
|
||||||
Fqdn: "peer-2.awesome-domain.com",
|
Fqdn: "peer-2.awesome-domain.com",
|
||||||
ConnStatus: "Connected",
|
ConnStatus: "Connected",
|
||||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||||
Relayed: true,
|
Relayed: true,
|
||||||
Direct: false,
|
Direct: false,
|
||||||
LocalIceCandidateType: "relay",
|
LocalIceCandidateType: "relay",
|
||||||
RemoteIceCandidateType: "prflx",
|
RemoteIceCandidateType: "prflx",
|
||||||
|
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
||||||
|
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
|
||||||
|
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
||||||
|
BytesRx: 2000,
|
||||||
|
BytesTx: 1000,
|
||||||
|
Latency: durationpb.New(time.Duration(10000000)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ManagementState: &proto.ManagementState{
|
ManagementState: &proto.ManagementState{
|
||||||
URL: "my-awesome-management.com:443",
|
URL: "my-awesome-management.com:443",
|
||||||
Connected: true,
|
Connected: true,
|
||||||
|
Error: "",
|
||||||
},
|
},
|
||||||
SignalState: &proto.SignalState{
|
SignalState: &proto.SignalState{
|
||||||
URL: "my-awesome-signal.com:443",
|
URL: "my-awesome-signal.com:443",
|
||||||
Connected: true,
|
Connected: true,
|
||||||
|
Error: "",
|
||||||
|
},
|
||||||
|
Relays: []*proto.RelayState{
|
||||||
|
{
|
||||||
|
URI: "stun:my-awesome-stun.com:3478",
|
||||||
|
Available: true,
|
||||||
|
Error: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||||
|
Available: false,
|
||||||
|
Error: "context: deadline exceeded",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
LocalPeerState: &proto.LocalPeerState{
|
LocalPeerState: &proto.LocalPeerState{
|
||||||
IP: "192.168.178.100/16",
|
IP: "192.168.178.100/16",
|
||||||
PubKey: "Some-Pub-Key",
|
PubKey: "Some-Pub-Key",
|
||||||
KernelInterface: true,
|
KernelInterface: true,
|
||||||
Fqdn: "some-localhost.awesome-domain.com",
|
Fqdn: "some-localhost.awesome-domain.com",
|
||||||
|
Routes: []string{
|
||||||
|
"10.10.0.0/24",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DnsServers: []*proto.NSGroupState{
|
||||||
|
{
|
||||||
|
Servers: []string{
|
||||||
|
"8.8.8.8:53",
|
||||||
|
},
|
||||||
|
Domains: nil,
|
||||||
|
Enabled: true,
|
||||||
|
Error: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Servers: []string{
|
||||||
|
"1.1.1.1:53",
|
||||||
|
"2.2.2.2:53",
|
||||||
|
},
|
||||||
|
Domains: []string{
|
||||||
|
"example.com",
|
||||||
|
"example.net",
|
||||||
|
},
|
||||||
|
Enabled: false,
|
||||||
|
Error: "timeout",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
DaemonVersion: "0.14.1",
|
DaemonVersion: "0.14.1",
|
||||||
@@ -82,6 +140,17 @@ var overview = statusOutputOverview{
|
|||||||
Local: "",
|
Local: "",
|
||||||
Remote: "",
|
Remote: "",
|
||||||
},
|
},
|
||||||
|
IceCandidateEndpoint: iceCandidateType{
|
||||||
|
Local: "",
|
||||||
|
Remote: "",
|
||||||
|
},
|
||||||
|
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
||||||
|
TransferReceived: 200,
|
||||||
|
TransferSent: 100,
|
||||||
|
Routes: []string{
|
||||||
|
"10.1.0.0/24",
|
||||||
|
},
|
||||||
|
Latency: time.Duration(10000000),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
IP: "192.168.178.102",
|
IP: "192.168.178.102",
|
||||||
@@ -95,6 +164,14 @@ var overview = statusOutputOverview{
|
|||||||
Local: "relay",
|
Local: "relay",
|
||||||
Remote: "prflx",
|
Remote: "prflx",
|
||||||
},
|
},
|
||||||
|
IceCandidateEndpoint: iceCandidateType{
|
||||||
|
Local: "10.0.0.1:10001",
|
||||||
|
Remote: "10.0.10.1:10002",
|
||||||
|
},
|
||||||
|
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
||||||
|
TransferReceived: 2000,
|
||||||
|
TransferSent: 1000,
|
||||||
|
Latency: time.Duration(10000000),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -103,15 +180,58 @@ var overview = statusOutputOverview{
|
|||||||
ManagementState: managementStateOutput{
|
ManagementState: managementStateOutput{
|
||||||
URL: "my-awesome-management.com:443",
|
URL: "my-awesome-management.com:443",
|
||||||
Connected: true,
|
Connected: true,
|
||||||
|
Error: "",
|
||||||
},
|
},
|
||||||
SignalState: signalStateOutput{
|
SignalState: signalStateOutput{
|
||||||
URL: "my-awesome-signal.com:443",
|
URL: "my-awesome-signal.com:443",
|
||||||
Connected: true,
|
Connected: true,
|
||||||
|
Error: "",
|
||||||
|
},
|
||||||
|
Relays: relayStateOutput{
|
||||||
|
Total: 2,
|
||||||
|
Available: 1,
|
||||||
|
Details: []relayStateOutputDetail{
|
||||||
|
{
|
||||||
|
URI: "stun:my-awesome-stun.com:3478",
|
||||||
|
Available: true,
|
||||||
|
Error: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||||
|
Available: false,
|
||||||
|
Error: "context: deadline exceeded",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
IP: "192.168.178.100/16",
|
IP: "192.168.178.100/16",
|
||||||
PubKey: "Some-Pub-Key",
|
PubKey: "Some-Pub-Key",
|
||||||
KernelInterface: true,
|
KernelInterface: true,
|
||||||
FQDN: "some-localhost.awesome-domain.com",
|
FQDN: "some-localhost.awesome-domain.com",
|
||||||
|
NSServerGroups: []nsServerGroupStateOutput{
|
||||||
|
{
|
||||||
|
Servers: []string{
|
||||||
|
"8.8.8.8:53",
|
||||||
|
},
|
||||||
|
Domains: nil,
|
||||||
|
Enabled: true,
|
||||||
|
Error: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Servers: []string{
|
||||||
|
"1.1.1.1:53",
|
||||||
|
"2.2.2.2:53",
|
||||||
|
},
|
||||||
|
Domains: []string{
|
||||||
|
"example.com",
|
||||||
|
"example.net",
|
||||||
|
},
|
||||||
|
Enabled: false,
|
||||||
|
Error: "timeout",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Routes: []string{
|
||||||
|
"10.10.0.0/24",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||||
@@ -145,107 +265,223 @@ func TestSortingOfPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParsingToJSON(t *testing.T) {
|
func TestParsingToJSON(t *testing.T) {
|
||||||
json, _ := parseToJSON(overview)
|
jsonString, _ := parseToJSON(overview)
|
||||||
|
|
||||||
//@formatter:off
|
//@formatter:off
|
||||||
expectedJSON := "{\"" +
|
expectedJSONString := `
|
||||||
"peers\":" +
|
{
|
||||||
"{" +
|
"peers": {
|
||||||
"\"total\":2," +
|
"total": 2,
|
||||||
"\"connected\":2," +
|
"connected": 2,
|
||||||
"\"details\":" +
|
"details": [
|
||||||
"[" +
|
{
|
||||||
"{" +
|
"fqdn": "peer-1.awesome-domain.com",
|
||||||
"\"fqdn\":\"peer-1.awesome-domain.com\"," +
|
"netbirdIp": "192.168.178.101",
|
||||||
"\"netbirdIp\":\"192.168.178.101\"," +
|
"publicKey": "Pubkey1",
|
||||||
"\"publicKey\":\"Pubkey1\"," +
|
"status": "Connected",
|
||||||
"\"status\":\"Connected\"," +
|
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
||||||
"\"lastStatusUpdate\":\"2001-01-01T01:01:01Z\"," +
|
"connectionType": "P2P",
|
||||||
"\"connectionType\":\"P2P\"," +
|
"direct": true,
|
||||||
"\"direct\":true," +
|
"iceCandidateType": {
|
||||||
"\"iceCandidateType\":" +
|
"local": "",
|
||||||
"{" +
|
"remote": ""
|
||||||
"\"local\":\"\"," +
|
},
|
||||||
"\"remote\":\"\"" +
|
"iceCandidateEndpoint": {
|
||||||
"}" +
|
"local": "",
|
||||||
"}," +
|
"remote": ""
|
||||||
"{" +
|
},
|
||||||
"\"fqdn\":\"peer-2.awesome-domain.com\"," +
|
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||||
"\"netbirdIp\":\"192.168.178.102\"," +
|
"transferReceived": 200,
|
||||||
"\"publicKey\":\"Pubkey2\"," +
|
"transferSent": 100,
|
||||||
"\"status\":\"Connected\"," +
|
"latency": 10000000,
|
||||||
"\"lastStatusUpdate\":\"2002-02-02T02:02:02Z\"," +
|
"quantumResistance": false,
|
||||||
"\"connectionType\":\"Relayed\"," +
|
"routes": [
|
||||||
"\"direct\":false," +
|
"10.1.0.0/24"
|
||||||
"\"iceCandidateType\":" +
|
]
|
||||||
"{" +
|
},
|
||||||
"\"local\":\"relay\"," +
|
{
|
||||||
"\"remote\":\"prflx\"" +
|
"fqdn": "peer-2.awesome-domain.com",
|
||||||
"}" +
|
"netbirdIp": "192.168.178.102",
|
||||||
"}" +
|
"publicKey": "Pubkey2",
|
||||||
"]" +
|
"status": "Connected",
|
||||||
"}," +
|
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
||||||
"\"cliVersion\":\"development\"," +
|
"connectionType": "Relayed",
|
||||||
"\"daemonVersion\":\"0.14.1\"," +
|
"direct": false,
|
||||||
"\"management\":" +
|
"iceCandidateType": {
|
||||||
"{" +
|
"local": "relay",
|
||||||
"\"url\":\"my-awesome-management.com:443\"," +
|
"remote": "prflx"
|
||||||
"\"connected\":true" +
|
},
|
||||||
"}," +
|
"iceCandidateEndpoint": {
|
||||||
"\"signal\":" +
|
"local": "10.0.0.1:10001",
|
||||||
"{\"" +
|
"remote": "10.0.10.1:10002"
|
||||||
"url\":\"my-awesome-signal.com:443\"," +
|
},
|
||||||
"\"connected\":true" +
|
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||||
"}," +
|
"transferReceived": 2000,
|
||||||
"\"netbirdIp\":\"192.168.178.100/16\"," +
|
"transferSent": 1000,
|
||||||
"\"publicKey\":\"Some-Pub-Key\"," +
|
"latency": 10000000,
|
||||||
"\"usesKernelInterface\":true," +
|
"quantumResistance": false,
|
||||||
"\"fqdn\":\"some-localhost.awesome-domain.com\"" +
|
"routes": null
|
||||||
"}"
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"cliVersion": "development",
|
||||||
|
"daemonVersion": "0.14.1",
|
||||||
|
"management": {
|
||||||
|
"url": "my-awesome-management.com:443",
|
||||||
|
"connected": true,
|
||||||
|
"error": ""
|
||||||
|
},
|
||||||
|
"signal": {
|
||||||
|
"url": "my-awesome-signal.com:443",
|
||||||
|
"connected": true,
|
||||||
|
"error": ""
|
||||||
|
},
|
||||||
|
"relays": {
|
||||||
|
"total": 2,
|
||||||
|
"available": 1,
|
||||||
|
"details": [
|
||||||
|
{
|
||||||
|
"uri": "stun:my-awesome-stun.com:3478",
|
||||||
|
"available": true,
|
||||||
|
"error": ""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
|
||||||
|
"available": false,
|
||||||
|
"error": "context: deadline exceeded"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"netbirdIp": "192.168.178.100/16",
|
||||||
|
"publicKey": "Some-Pub-Key",
|
||||||
|
"usesKernelInterface": true,
|
||||||
|
"fqdn": "some-localhost.awesome-domain.com",
|
||||||
|
"quantumResistance": false,
|
||||||
|
"quantumResistancePermissive": false,
|
||||||
|
"routes": [
|
||||||
|
"10.10.0.0/24"
|
||||||
|
],
|
||||||
|
"dnsServers": [
|
||||||
|
{
|
||||||
|
"servers": [
|
||||||
|
"8.8.8.8:53"
|
||||||
|
],
|
||||||
|
"domains": null,
|
||||||
|
"enabled": true,
|
||||||
|
"error": ""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"servers": [
|
||||||
|
"1.1.1.1:53",
|
||||||
|
"2.2.2.2:53"
|
||||||
|
],
|
||||||
|
"domains": [
|
||||||
|
"example.com",
|
||||||
|
"example.net"
|
||||||
|
],
|
||||||
|
"enabled": false,
|
||||||
|
"error": "timeout"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
|
|
||||||
assert.Equal(t, expectedJSON, json)
|
var expectedJSON bytes.Buffer
|
||||||
|
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
|
||||||
|
|
||||||
|
assert.Equal(t, expectedJSON.String(), jsonString)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParsingToYAML(t *testing.T) {
|
func TestParsingToYAML(t *testing.T) {
|
||||||
yaml, _ := parseToYAML(overview)
|
yaml, _ := parseToYAML(overview)
|
||||||
|
|
||||||
expectedYAML := "peers:\n" +
|
expectedYAML :=
|
||||||
" total: 2\n" +
|
`peers:
|
||||||
" connected: 2\n" +
|
total: 2
|
||||||
" details:\n" +
|
connected: 2
|
||||||
" - fqdn: peer-1.awesome-domain.com\n" +
|
details:
|
||||||
" netbirdIp: 192.168.178.101\n" +
|
- fqdn: peer-1.awesome-domain.com
|
||||||
" publicKey: Pubkey1\n" +
|
netbirdIp: 192.168.178.101
|
||||||
" status: Connected\n" +
|
publicKey: Pubkey1
|
||||||
" lastStatusUpdate: 2001-01-01T01:01:01Z\n" +
|
status: Connected
|
||||||
" connectionType: P2P\n" +
|
lastStatusUpdate: 2001-01-01T01:01:01Z
|
||||||
" direct: true\n" +
|
connectionType: P2P
|
||||||
" iceCandidateType:\n" +
|
direct: true
|
||||||
" local: \"\"\n" +
|
iceCandidateType:
|
||||||
" remote: \"\"\n" +
|
local: ""
|
||||||
" - fqdn: peer-2.awesome-domain.com\n" +
|
remote: ""
|
||||||
" netbirdIp: 192.168.178.102\n" +
|
iceCandidateEndpoint:
|
||||||
" publicKey: Pubkey2\n" +
|
local: ""
|
||||||
" status: Connected\n" +
|
remote: ""
|
||||||
" lastStatusUpdate: 2002-02-02T02:02:02Z\n" +
|
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||||
" connectionType: Relayed\n" +
|
transferReceived: 200
|
||||||
" direct: false\n" +
|
transferSent: 100
|
||||||
" iceCandidateType:\n" +
|
latency: 10ms
|
||||||
" local: relay\n" +
|
quantumResistance: false
|
||||||
" remote: prflx\n" +
|
routes:
|
||||||
"cliVersion: development\n" +
|
- 10.1.0.0/24
|
||||||
"daemonVersion: 0.14.1\n" +
|
- fqdn: peer-2.awesome-domain.com
|
||||||
"management:\n" +
|
netbirdIp: 192.168.178.102
|
||||||
" url: my-awesome-management.com:443\n" +
|
publicKey: Pubkey2
|
||||||
" connected: true\n" +
|
status: Connected
|
||||||
"signal:\n" +
|
lastStatusUpdate: 2002-02-02T02:02:02Z
|
||||||
" url: my-awesome-signal.com:443\n" +
|
connectionType: Relayed
|
||||||
" connected: true\n" +
|
direct: false
|
||||||
"netbirdIp: 192.168.178.100/16\n" +
|
iceCandidateType:
|
||||||
"publicKey: Some-Pub-Key\n" +
|
local: relay
|
||||||
"usesKernelInterface: true\n" +
|
remote: prflx
|
||||||
"fqdn: some-localhost.awesome-domain.com\n"
|
iceCandidateEndpoint:
|
||||||
|
local: 10.0.0.1:10001
|
||||||
|
remote: 10.0.10.1:10002
|
||||||
|
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||||
|
transferReceived: 2000
|
||||||
|
transferSent: 1000
|
||||||
|
latency: 10ms
|
||||||
|
quantumResistance: false
|
||||||
|
routes: []
|
||||||
|
cliVersion: development
|
||||||
|
daemonVersion: 0.14.1
|
||||||
|
management:
|
||||||
|
url: my-awesome-management.com:443
|
||||||
|
connected: true
|
||||||
|
error: ""
|
||||||
|
signal:
|
||||||
|
url: my-awesome-signal.com:443
|
||||||
|
connected: true
|
||||||
|
error: ""
|
||||||
|
relays:
|
||||||
|
total: 2
|
||||||
|
available: 1
|
||||||
|
details:
|
||||||
|
- uri: stun:my-awesome-stun.com:3478
|
||||||
|
available: true
|
||||||
|
error: ""
|
||||||
|
- uri: turns:my-awesome-turn.com:443?transport=tcp
|
||||||
|
available: false
|
||||||
|
error: 'context: deadline exceeded'
|
||||||
|
netbirdIp: 192.168.178.100/16
|
||||||
|
publicKey: Some-Pub-Key
|
||||||
|
usesKernelInterface: true
|
||||||
|
fqdn: some-localhost.awesome-domain.com
|
||||||
|
quantumResistance: false
|
||||||
|
quantumResistancePermissive: false
|
||||||
|
routes:
|
||||||
|
- 10.10.0.0/24
|
||||||
|
dnsServers:
|
||||||
|
- servers:
|
||||||
|
- 8.8.8.8:53
|
||||||
|
domains: []
|
||||||
|
enabled: true
|
||||||
|
error: ""
|
||||||
|
- servers:
|
||||||
|
- 1.1.1.1:53
|
||||||
|
- 2.2.2.2:53
|
||||||
|
domains:
|
||||||
|
- example.com
|
||||||
|
- example.net
|
||||||
|
enabled: false
|
||||||
|
error: timeout
|
||||||
|
`
|
||||||
|
|
||||||
assert.Equal(t, expectedYAML, yaml)
|
assert.Equal(t, expectedYAML, yaml)
|
||||||
}
|
}
|
||||||
@@ -253,50 +489,78 @@ func TestParsingToYAML(t *testing.T) {
|
|||||||
func TestParsingToDetail(t *testing.T) {
|
func TestParsingToDetail(t *testing.T) {
|
||||||
detail := parseToFullDetailSummary(overview)
|
detail := parseToFullDetailSummary(overview)
|
||||||
|
|
||||||
expectedDetail := "Peers detail:\n" +
|
expectedDetail :=
|
||||||
" peer-1.awesome-domain.com:\n" +
|
`Peers detail:
|
||||||
" NetBird IP: 192.168.178.101\n" +
|
peer-1.awesome-domain.com:
|
||||||
" Public key: Pubkey1\n" +
|
NetBird IP: 192.168.178.101
|
||||||
" Status: Connected\n" +
|
Public key: Pubkey1
|
||||||
" -- detail --\n" +
|
Status: Connected
|
||||||
" Connection type: P2P\n" +
|
-- detail --
|
||||||
" Direct: true\n" +
|
Connection type: P2P
|
||||||
" ICE candidate (Local/Remote): -/-\n" +
|
Direct: true
|
||||||
" Last connection update: 2001-01-01 01:01:01\n" +
|
ICE candidate (Local/Remote): -/-
|
||||||
"\n" +
|
ICE candidate endpoints (Local/Remote): -/-
|
||||||
" peer-2.awesome-domain.com:\n" +
|
Last connection update: 2001-01-01 01:01:01
|
||||||
" NetBird IP: 192.168.178.102\n" +
|
Last WireGuard handshake: 2001-01-01 01:01:02
|
||||||
" Public key: Pubkey2\n" +
|
Transfer status (received/sent) 200 B/100 B
|
||||||
" Status: Connected\n" +
|
Quantum resistance: false
|
||||||
" -- detail --\n" +
|
Routes: 10.1.0.0/24
|
||||||
" Connection type: Relayed\n" +
|
Latency: 10ms
|
||||||
" Direct: false\n" +
|
|
||||||
" ICE candidate (Local/Remote): relay/prflx\n" +
|
peer-2.awesome-domain.com:
|
||||||
" Last connection update: 2002-02-02 02:02:02\n" +
|
NetBird IP: 192.168.178.102
|
||||||
"\n" +
|
Public key: Pubkey2
|
||||||
"Daemon version: 0.14.1\n" +
|
Status: Connected
|
||||||
"CLI version: development\n" +
|
-- detail --
|
||||||
"Management: Connected to my-awesome-management.com:443\n" +
|
Connection type: Relayed
|
||||||
"Signal: Connected to my-awesome-signal.com:443\n" +
|
Direct: false
|
||||||
"FQDN: some-localhost.awesome-domain.com\n" +
|
ICE candidate (Local/Remote): relay/prflx
|
||||||
"NetBird IP: 192.168.178.100/16\n" +
|
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||||
"Interface type: Kernel\n" +
|
Last connection update: 2002-02-02 02:02:02
|
||||||
"Peers count: 2/2 Connected\n"
|
Last WireGuard handshake: 2002-02-02 02:02:03
|
||||||
|
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||||
|
Quantum resistance: false
|
||||||
|
Routes: -
|
||||||
|
Latency: 10ms
|
||||||
|
|
||||||
|
Daemon version: 0.14.1
|
||||||
|
CLI version: development
|
||||||
|
Management: Connected to my-awesome-management.com:443
|
||||||
|
Signal: Connected to my-awesome-signal.com:443
|
||||||
|
Relays:
|
||||||
|
[stun:my-awesome-stun.com:3478] is Available
|
||||||
|
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
||||||
|
Nameservers:
|
||||||
|
[8.8.8.8:53] for [.] is Available
|
||||||
|
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
||||||
|
FQDN: some-localhost.awesome-domain.com
|
||||||
|
NetBird IP: 192.168.178.100/16
|
||||||
|
Interface type: Kernel
|
||||||
|
Quantum resistance: false
|
||||||
|
Routes: 10.10.0.0/24
|
||||||
|
Peers count: 2/2 Connected
|
||||||
|
`
|
||||||
|
|
||||||
assert.Equal(t, expectedDetail, detail)
|
assert.Equal(t, expectedDetail, detail)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParsingToShortVersion(t *testing.T) {
|
func TestParsingToShortVersion(t *testing.T) {
|
||||||
shortVersion := parseGeneralSummary(overview, false)
|
shortVersion := parseGeneralSummary(overview, false, false, false)
|
||||||
|
|
||||||
expectedString := "Daemon version: 0.14.1\n" +
|
expectedString :=
|
||||||
"CLI version: development\n" +
|
`Daemon version: 0.14.1
|
||||||
"Management: Connected\n" +
|
CLI version: development
|
||||||
"Signal: Connected\n" +
|
Management: Connected
|
||||||
"FQDN: some-localhost.awesome-domain.com\n" +
|
Signal: Connected
|
||||||
"NetBird IP: 192.168.178.100/16\n" +
|
Relays: 1/2 Available
|
||||||
"Interface type: Kernel\n" +
|
Nameservers: 1/2 Available
|
||||||
"Peers count: 2/2 Connected\n"
|
FQDN: some-localhost.awesome-domain.com
|
||||||
|
NetBird IP: 192.168.178.100/16
|
||||||
|
Interface type: Kernel
|
||||||
|
Quantum resistance: false
|
||||||
|
Routes: 10.10.0.0/24
|
||||||
|
Peers count: 2/2 Connected
|
||||||
|
`
|
||||||
|
|
||||||
assert.Equal(t, expectedString, shortVersion)
|
assert.Equal(t, expectedString, shortVersion)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
@@ -78,8 +79,8 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
iv, _ := integrations.NewIntegratedValidator(eventStore)
|
||||||
eventStore, false)
|
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||||
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -83,17 +84,26 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ic := internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
NATExternalIPs: natExternalIPs,
|
NATExternalIPs: natExternalIPs,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
|
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||||
ic.RosenpassEnabled = &rosenpassEnabled
|
ic.RosenpassEnabled = &rosenpassEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||||
|
ic.RosenpassPermissive = &rosenpassPermissive
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
|
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -110,6 +120,18 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.PreSharedKey = &preSharedKey
|
ic.PreSharedKey = &preSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
|
ic.DisableAutoConnect = &autoConnectDisabled
|
||||||
|
|
||||||
|
if autoConnectDisabled {
|
||||||
|
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !autoConnectDisabled {
|
||||||
|
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(ic)
|
config, err := internal.UpdateOrCreateConfig(ic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
@@ -129,7 +151,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
|
||||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -163,7 +184,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: setupKey,
|
SetupKey: setupKey,
|
||||||
PreSharedKey: preSharedKey,
|
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
NatExternalIPs: natExternalIPs,
|
NatExternalIPs: natExternalIPs,
|
||||||
@@ -171,12 +191,29 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
|
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||||
|
}
|
||||||
|
|
||||||
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
|
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||||
loginRequest.RosenpassEnabled = &rosenpassEnabled
|
loginRequest.RosenpassEnabled = &rosenpassEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||||
|
loginRequest.RosenpassPermissive = &rosenpassPermissive
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
|
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
|
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ type HTTPClient interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
||||||
type AuthFlowInfo struct {
|
type AuthFlowInfo struct { //nolint:revive
|
||||||
DeviceCode string `json:"device_code"`
|
DeviceCode string `json:"device_code"`
|
||||||
UserCode string `json:"user_code"`
|
UserCode string `json:"user_code"`
|
||||||
VerificationURI string `json:"verification_uri"`
|
VerificationURI string `json:"verification_uri"`
|
||||||
|
|||||||
@@ -30,20 +30,26 @@ const (
|
|||||||
DefaultAdminURL = "https://app.netbird.io:443"
|
DefaultAdminURL = "https://app.netbird.io:443"
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
var defaultInterfaceBlacklist = []string{
|
||||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
|
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||||
|
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||||
|
}
|
||||||
|
|
||||||
// ConfigInput carries configuration changes to the client
|
// ConfigInput carries configuration changes to the client
|
||||||
type ConfigInput struct {
|
type ConfigInput struct {
|
||||||
ManagementURL string
|
ManagementURL string
|
||||||
AdminURL string
|
AdminURL string
|
||||||
ConfigPath string
|
ConfigPath string
|
||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
NATExternalIPs []string
|
ServerSSHAllowed *bool
|
||||||
CustomDNSAddress []byte
|
NATExternalIPs []string
|
||||||
RosenpassEnabled *bool
|
CustomDNSAddress []byte
|
||||||
InterfaceName *string
|
RosenpassEnabled *bool
|
||||||
WireguardPort *int
|
RosenpassPermissive *bool
|
||||||
|
InterfaceName *string
|
||||||
|
WireguardPort *int
|
||||||
|
DisableAutoConnect *bool
|
||||||
|
ExtraIFaceBlackList []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@@ -58,6 +64,8 @@ type Config struct {
|
|||||||
IFaceBlackList []string
|
IFaceBlackList []string
|
||||||
DisableIPv6Discovery bool
|
DisableIPv6Discovery bool
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
|
RosenpassPermissive bool
|
||||||
|
ServerSSHAllowed *bool
|
||||||
// SSHKey is a private SSH key in a PEM format
|
// SSHKey is a private SSH key in a PEM format
|
||||||
SSHKey string
|
SSHKey string
|
||||||
|
|
||||||
@@ -79,6 +87,10 @@ type Config struct {
|
|||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
// CustomDNSAddress sets the DNS resolver listening address in format ip:port
|
// CustomDNSAddress sets the DNS resolver listening address in format ip:port
|
||||||
CustomDNSAddress string
|
CustomDNSAddress string
|
||||||
|
|
||||||
|
// DisableAutoConnect determines whether the client should not start with the service
|
||||||
|
// it's set to false by default due to backwards compatibility
|
||||||
|
DisableAutoConnect bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
@@ -88,6 +100,7 @@ func ReadConfig(configPath string) (*Config, error) {
|
|||||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +165,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
DisableIPv6Discovery: false,
|
DisableIPv6Discovery: false,
|
||||||
NATExternalIPs: input.NATExternalIPs,
|
NATExternalIPs: input.NATExternalIPs,
|
||||||
CustomDNSAddress: string(input.CustomDNSAddress),
|
CustomDNSAddress: string(input.CustomDNSAddress),
|
||||||
|
ServerSSHAllowed: util.False(),
|
||||||
|
DisableAutoConnect: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||||
@@ -186,6 +201,14 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.RosenpassPermissive != nil {
|
||||||
|
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.ServerSSHAllowed != nil {
|
||||||
|
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||||
|
}
|
||||||
|
|
||||||
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -200,7 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config.AdminURL = newURL
|
config.AdminURL = newURL
|
||||||
}
|
}
|
||||||
|
|
||||||
config.IFaceBlackList = defaultInterfaceBlacklist
|
// nolint:gocritic
|
||||||
|
config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...)
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,6 +304,33 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
refresh = true
|
refresh = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.RosenpassPermissive != nil {
|
||||||
|
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||||
|
refresh = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.DisableAutoConnect != nil {
|
||||||
|
config.DisableAutoConnect = *input.DisableAutoConnect
|
||||||
|
refresh = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.ServerSSHAllowed != nil {
|
||||||
|
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||||
|
refresh = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ServerSSHAllowed == nil {
|
||||||
|
config.ServerSSHAllowed = util.True()
|
||||||
|
refresh = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(input.ExtraIFaceBlackList) > 0 {
|
||||||
|
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
|
||||||
|
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
|
||||||
|
refresh = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if refresh {
|
if refresh {
|
||||||
// since we have new management URL, we need to update config file
|
// since we have new management URL, we need to update config file
|
||||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||||
@@ -344,7 +395,6 @@ func configFileIsExists(path string) bool {
|
|||||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||||
// The check is performed only for the NetBird's managed version.
|
// The check is performed only for the NetBird's managed version.
|
||||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||||
|
|
||||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ func TestGetConfig(t *testing.T) {
|
|||||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -86,6 +85,26 @@ func TestGetConfig(t *testing.T) {
|
|||||||
assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL)
|
assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtraIFaceBlackList(t *testing.T) {
|
||||||
|
extraIFaceBlackList := []string{"eth1"}
|
||||||
|
path := filepath.Join(t.TempDir(), "config.json")
|
||||||
|
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ConfigPath: path,
|
||||||
|
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Contains(t, config.IFaceBlackList, "eth1")
|
||||||
|
readConf, err := util.ReadJson(path, config)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1")
|
||||||
|
}
|
||||||
|
|
||||||
func TestHiddenPreSharedKey(t *testing.T) {
|
func TestHiddenPreSharedKey(t *testing.T) {
|
||||||
hidden := "**********"
|
hidden := "**********"
|
||||||
samplePreSharedKey := "mysecretpresharedkey"
|
samplePreSharedKey := "mysecretpresharedkey"
|
||||||
@@ -111,7 +130,6 @@ func TestHiddenPreSharedKey(t *testing.T) {
|
|||||||
ConfigPath: cfgFile,
|
ConfigPath: cfgFile,
|
||||||
PreSharedKey: tt.preSharedKey,
|
PreSharedKey: tt.preSharedKey,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get cfg: %s", err)
|
t.Fatalf("failed to get cfg: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -22,16 +25,39 @@ import (
|
|||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||||
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunClientWithProbes runs the client's main logic with probes attached
|
||||||
|
func RunClientWithProbes(
|
||||||
|
ctx context.Context,
|
||||||
|
config *Config,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
mgmProbe *Probe,
|
||||||
|
signalProbe *Probe,
|
||||||
|
relayProbe *Probe,
|
||||||
|
wgProbe *Probe,
|
||||||
|
) error {
|
||||||
|
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunClientMobile with main logic on mobile system
|
// RunClientMobile with main logic on mobile system
|
||||||
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
func RunClientMobile(
|
||||||
|
ctx context.Context,
|
||||||
|
config *Config,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
tunAdapter iface.TunAdapter,
|
||||||
|
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||||
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
|
dnsAddresses []string,
|
||||||
|
dnsReadyListener dns.ReadyListener,
|
||||||
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
TunAdapter: tunAdapter,
|
TunAdapter: tunAdapter,
|
||||||
@@ -40,20 +66,48 @@ func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.S
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
}
|
}
|
||||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Status, fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager) error {
|
func RunClientiOS(
|
||||||
|
ctx context.Context,
|
||||||
|
config *Config,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
fileDescriptor int32,
|
||||||
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
|
dnsManager dns.IosDnsManager,
|
||||||
|
) error {
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
FileDescriptor: fileDescriptor,
|
FileDescriptor: fileDescriptor,
|
||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
}
|
}
|
||||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
func runClient(
|
||||||
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
|
ctx context.Context,
|
||||||
|
config *Config,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
mobileDependency MobileDependency,
|
||||||
|
mgmProbe *Probe,
|
||||||
|
signalProbe *Probe,
|
||||||
|
relayProbe *Probe,
|
||||||
|
wgProbe *Probe,
|
||||||
|
) error {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
|
// Check if client was not shut down in a clean way and restore DNS config if required.
|
||||||
|
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||||
|
if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
|
||||||
|
log.Errorf("checking unclean shutdown error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
@@ -103,7 +157,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
|
|
||||||
engineCtx, cancel := context.WithCancel(ctx)
|
engineCtx, cancel := context.WithCancel(ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
statusRecorder.MarkManagementDisconnected()
|
statusRecorder.MarkManagementDisconnected(state.err)
|
||||||
statusRecorder.CleanLocalPeerState()
|
statusRecorder.CleanLocalPeerState()
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -152,8 +206,10 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
|
|
||||||
statusRecorder.UpdateSignalAddress(signalURL)
|
statusRecorder.UpdateSignalAddress(signalURL)
|
||||||
|
|
||||||
statusRecorder.MarkSignalDisconnected()
|
statusRecorder.MarkSignalDisconnected(nil)
|
||||||
defer statusRecorder.MarkSignalDisconnected()
|
defer func() {
|
||||||
|
statusRecorder.MarkSignalDisconnected(state.err)
|
||||||
|
}()
|
||||||
|
|
||||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
||||||
@@ -181,7 +237,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||||
err = engine.Start()
|
err = engine.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
@@ -204,7 +260,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
|||||||
|
|
||||||
log.Info("stopped NetBird client")
|
log.Info("stopped NetBird client")
|
||||||
|
|
||||||
if _, err := state.Status(); err == ErrResetConnection {
|
if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,6 +292,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
NATExternalIPs: config.NATExternalIPs,
|
NATExternalIPs: config.NATExternalIPs,
|
||||||
CustomDNSAddress: config.CustomDNSAddress,
|
CustomDNSAddress: config.CustomDNSAddress,
|
||||||
RosenpassEnabled: config.RosenpassEnabled,
|
RosenpassEnabled: config.RosenpassEnabled,
|
||||||
|
RosenpassPermissive: config.RosenpassPermissive,
|
||||||
|
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const dbusDefaultFlag = 0
|
const dbusDefaultFlag = 0
|
||||||
@@ -14,6 +16,7 @@ const dbusDefaultFlag = 0
|
|||||||
func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
|
func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
|
||||||
obj, closeConn, err := getDbusObject(dest, path)
|
obj, closeConn, err := getDbusObject(dest, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Tracef("error getting dbus object: %s", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
@@ -21,14 +24,18 @@ func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
|
|||||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store()
|
if err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||||
return err == nil
|
log.Tracef("error calling dbus: %s", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) {
|
func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) {
|
||||||
conn, err := dbus.SystemBus()
|
conn, err := dbus.SystemBus()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, fmt.Errorf("get dbus: %w", err)
|
||||||
}
|
}
|
||||||
obj := conn.Object(dest, path)
|
obj := conn.Object(dest, path)
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,12 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@@ -23,12 +24,22 @@ const (
|
|||||||
fileMaxNumberOfSearchDomains = 6
|
fileMaxNumberOfSearchDomains = 6
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dnsFailoverTimeout = 4 * time.Second
|
||||||
|
dnsFailoverAttempts = 1
|
||||||
|
)
|
||||||
|
|
||||||
type fileConfigurator struct {
|
type fileConfigurator struct {
|
||||||
originalPerms os.FileMode
|
repair *repair
|
||||||
|
|
||||||
|
originalPerms os.FileMode
|
||||||
|
nbNameserverIP string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFileConfigurator() (hostManager, error) {
|
func newFileConfigurator() (hostManager, error) {
|
||||||
return &fileConfigurator{}, nil
|
fc := &fileConfigurator{}
|
||||||
|
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
||||||
|
return fc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) supportCustomPort() bool {
|
func (f *fileConfigurator) supportCustomPort() bool {
|
||||||
@@ -46,7 +57,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
if backupFileExist {
|
if backupFileExist {
|
||||||
err = f.restore()
|
err = f.restore()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %s", err)
|
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||||
@@ -55,66 +66,150 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
if !backupFileExist {
|
if !backupFileExist {
|
||||||
err = f.backup()
|
err = f.backup()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to backup the resolv.conf file")
|
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
searchDomainList := searchDomains(config)
|
nbSearchDomains := searchDomains(config)
|
||||||
|
f.nbNameserverIP = config.ServerIP
|
||||||
|
|
||||||
originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
|
resolvConf, err := parseBackupResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)
|
f.repair.stopWatchFileChanges()
|
||||||
|
|
||||||
|
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
|
||||||
|
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
||||||
|
nameServers := generateNsList(nbNameserverIP, cfg)
|
||||||
|
|
||||||
|
options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
||||||
buf := prepareResolvConfContent(
|
buf := prepareResolvConfContent(
|
||||||
searchDomainList,
|
searchDomainList,
|
||||||
append([]string{config.ServerIP}, nameServers...),
|
nameServers,
|
||||||
others)
|
options)
|
||||||
|
|
||||||
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
||||||
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
restoreErr := f.restore()
|
restoreErr := f.restore()
|
||||||
if restoreErr != nil {
|
if restoreErr != nil {
|
||||||
log.Errorf("attempt to restore default file failed with error: %s", err)
|
log.Errorf("attempt to restore default file failed with error: %s", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
|
return fmt.Errorf("creating resolver file %s. Error: %w", defaultResolvConfPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
||||||
|
|
||||||
|
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
||||||
|
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
|
||||||
|
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) restoreHostDNS() error {
|
func (f *fileConfigurator) restoreHostDNS() error {
|
||||||
|
f.repair.stopWatchFileChanges()
|
||||||
return f.restore()
|
return f.restore()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) backup() error {
|
func (f *fileConfigurator) backup() error {
|
||||||
stats, err := os.Stat(defaultResolvConfPath)
|
stats, err := os.Stat(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while checking stats for %s file. Error: %s", defaultResolvConfPath, err)
|
return fmt.Errorf("checking stats for %s file. Error: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.originalPerms = stats.Mode()
|
f.originalPerms = stats.Mode()
|
||||||
|
|
||||||
err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation)
|
err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while backing up the %s file. Error: %s", defaultResolvConfPath, err)
|
return fmt.Errorf("backing up %s: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) restore() error {
|
func (f *fileConfigurator) restore() error {
|
||||||
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while restoring the %s file from %s. Error: %s", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||||
|
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
||||||
|
resolvConf, err := parseDefaultResolvConf()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse current resolv.conf: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// no current nameservers set -> restore
|
||||||
|
if len(resolvConf.nameServers) == 0 {
|
||||||
|
return restoreResolvConfFile()
|
||||||
|
}
|
||||||
|
|
||||||
|
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
|
||||||
|
// not a valid first nameserver -> restore
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
|
||||||
|
return restoreResolvConfFile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// current address is still netbird's non-available dns address -> restore
|
||||||
|
// comparing parsed addresses only, to remove ambiguity
|
||||||
|
if currentDNSAddress.String() == storedDNSAddress.String() {
|
||||||
|
return restoreResolvConfFile()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func restoreResolvConfFile() error {
|
||||||
|
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)
|
||||||
|
|
||||||
|
if err := copyFile(fileUncleanShutdownResolvConfLocation, defaultResolvConfPath); err != nil {
|
||||||
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||||
|
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list
|
||||||
|
func generateNsList(nbNameserverIP string, cfg *resolvConf) []string {
|
||||||
|
ns := make([]string, 1, len(cfg.nameServers)+1)
|
||||||
|
ns[0] = nbNameserverIP
|
||||||
|
for _, cfgNs := range cfg.nameServers {
|
||||||
|
if nbNameserverIP != cfgNs {
|
||||||
|
ns = append(ns, cfgNs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
|
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
|
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
|
||||||
@@ -150,70 +245,6 @@ func searchDomains(config HostDNSConfig) []string {
|
|||||||
return listOfDomains
|
return listOfDomains
|
||||||
}
|
}
|
||||||
|
|
||||||
func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
|
|
||||||
file, err := os.Open(resolvconfFile)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf(`could not read existing resolv.conf`)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
reader := bufio.NewReader(file)
|
|
||||||
|
|
||||||
for {
|
|
||||||
lineBytes, isPrefix, readErr := reader.ReadLine()
|
|
||||||
if readErr != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if isPrefix {
|
|
||||||
err = fmt.Errorf(`resolv.conf line too long`)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
line := strings.TrimSpace(string(lineBytes))
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "domain") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
|
|
||||||
line = strings.ReplaceAll(line, "rotate", "")
|
|
||||||
splitLines := strings.Fields(line)
|
|
||||||
if len(splitLines) == 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
line = strings.Join(splitLines, " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "search") {
|
|
||||||
splitLines := strings.Fields(line)
|
|
||||||
if len(splitLines) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
searchDomains = splitLines[1:]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "nameserver") {
|
|
||||||
splitLines := strings.Fields(line)
|
|
||||||
if len(splitLines) != 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
nameServers = append(nameServers, splitLines[1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
others = append(others, line)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// merge search Domains lists and cut off the list if it is too long
|
// merge search Domains lists and cut off the list if it is too long
|
||||||
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
|
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
|
||||||
lineSize := len("search")
|
lineSize := len("search")
|
||||||
@@ -230,6 +261,19 @@ func mergeSearchDomains(searchDomains []string, originalSearchDomains []string)
|
|||||||
// return with the number of characters in the searchDomains line
|
// return with the number of characters in the searchDomains line
|
||||||
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
|
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
|
||||||
for _, sd := range vs {
|
for _, sd := range vs {
|
||||||
|
duplicated := false
|
||||||
|
for _, fs := range *s {
|
||||||
|
if fs == sd {
|
||||||
|
duplicated = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if duplicated {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
tmpCharsNumber := initialLineChars + 1 + len(sd)
|
tmpCharsNumber := initialLineChars + 1 + len(sd)
|
||||||
if tmpCharsNumber > fileMaxLineCharsLimit {
|
if tmpCharsNumber > fileMaxLineCharsLimit {
|
||||||
// lets log all skipped Domains
|
// lets log all skipped Domains
|
||||||
@@ -246,23 +290,39 @@ func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string
|
|||||||
}
|
}
|
||||||
*s = append(*s, sd)
|
*s = append(*s, sd)
|
||||||
}
|
}
|
||||||
|
|
||||||
return initialLineChars
|
return initialLineChars
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyFile(src, dest string) error {
|
func copyFile(src, dest string) error {
|
||||||
stats, err := os.Stat(src)
|
stats, err := os.Stat(src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while checking stats for %s file when copying it. Error: %s", src, err)
|
return fmt.Errorf("checking stats for %s file when copying it. Error: %s", src, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bytesRead, err := os.ReadFile(src)
|
bytesRead, err := os.ReadFile(src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while reading the file %s file for copy. Error: %s", src, err)
|
return fmt.Errorf("reading the file %s file for copy. Error: %s", src, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.WriteFile(dest, bytesRead, stats.Mode())
|
err = os.WriteFile(dest, bytesRead, stats.Mode())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an writing the destination file %s for copy. Error: %s", dest, err)
|
return fmt.Errorf("writing the destination file %s for copy. Error: %s", dest, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isContains(subList []string, list []string) bool {
|
||||||
|
for _, sl := range subList {
|
||||||
|
var found bool
|
||||||
|
for _, l := range list {
|
||||||
|
if sl == l {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,7 +9,7 @@ import (
|
|||||||
|
|
||||||
func Test_mergeSearchDomains(t *testing.T) {
|
func Test_mergeSearchDomains(t *testing.T) {
|
||||||
searchDomains := []string{"a", "b"}
|
searchDomains := []string{"a", "b"}
|
||||||
originDomains := []string{"a", "b"}
|
originDomains := []string{"c", "d"}
|
||||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||||
if len(mergedDomains) != 4 {
|
if len(mergedDomains) != 4 {
|
||||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
|
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
|
||||||
@@ -49,6 +51,67 @@ func Test_mergeSearchTooLongDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_isContains(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
subList []string
|
||||||
|
list []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
args: args{
|
||||||
|
subList: []string{"a", "b", "c"},
|
||||||
|
list: []string{"a", "b", "c"},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
args: args{
|
||||||
|
subList: []string{"a"},
|
||||||
|
list: []string{"a", "b", "c"},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
args: args{
|
||||||
|
subList: []string{"d"},
|
||||||
|
list: []string{"a", "b", "c"},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
args: args{
|
||||||
|
subList: []string{"a"},
|
||||||
|
list: []string{},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
args: args{
|
||||||
|
subList: []string{},
|
||||||
|
list: []string{"b"},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
args: args{
|
||||||
|
subList: []string{},
|
||||||
|
list: []string{},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run("list check test", func(t *testing.T) {
|
||||||
|
if got := isContains(tt.args.subList, tt.args.list); got != tt.want {
|
||||||
|
t.Errorf("isContains() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getLongLine() string {
|
func getLongLine() string {
|
||||||
x := "search "
|
x := "search "
|
||||||
for {
|
for {
|
||||||
|
|||||||
168
client/internal/dns/file_parser_linux.go
Normal file
168
client/internal/dns/file_parser_linux.go
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultResolvConfPath = "/etc/resolv.conf"
|
||||||
|
)
|
||||||
|
|
||||||
|
var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
|
||||||
|
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)
|
||||||
|
|
||||||
|
type resolvConf struct {
|
||||||
|
nameServers []string
|
||||||
|
searchDomains []string
|
||||||
|
others []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolvConf) String() string {
|
||||||
|
return fmt.Sprintf("search domains: %v, name servers: %v, others: %s", r.searchDomains, r.nameServers, r.others)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDefaultResolvConf() (*resolvConf, error) {
|
||||||
|
return parseResolvConfFile(defaultResolvConfPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBackupResolvConf() (*resolvConf, error) {
|
||||||
|
return parseResolvConfFile(fileDefaultResolvConfBackupLocation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||||
|
rconf := &resolvConf{
|
||||||
|
searchDomains: make([]string, 0),
|
||||||
|
nameServers: make([]string, 0),
|
||||||
|
others: make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Open(resolvConfFile)
|
||||||
|
if err != nil {
|
||||||
|
return rconf, fmt.Errorf("failed to open %s file: %w", resolvConfFile, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
|
log.Errorf("failed closing %s: %s", resolvConfFile, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
cur, err := os.ReadFile(resolvConfFile)
|
||||||
|
if err != nil {
|
||||||
|
return rconf, fmt.Errorf("failed to read %s file: %w", resolvConfFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cur) == 0 {
|
||||||
|
return rconf, fmt.Errorf("file is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, line := range strings.Split(string(cur), "\n") {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "domain") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
|
||||||
|
line = strings.ReplaceAll(line, "rotate", "")
|
||||||
|
splitLines := strings.Fields(line)
|
||||||
|
if len(splitLines) == 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
line = strings.Join(splitLines, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "search") {
|
||||||
|
splitLines := strings.Fields(line)
|
||||||
|
if len(splitLines) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rconf.searchDomains = splitLines[1:]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "nameserver") {
|
||||||
|
splitLines := strings.Fields(line)
|
||||||
|
if len(splitLines) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rconf.nameServers = append(rconf.nameServers, splitLines[1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if line != "" {
|
||||||
|
rconf.others = append(rconf.others, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rconf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
|
||||||
|
// otherwise it adds a new option with timeout and attempts.
|
||||||
|
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
|
||||||
|
configs := make([]string, len(input))
|
||||||
|
copy(configs, input)
|
||||||
|
|
||||||
|
for i, config := range configs {
|
||||||
|
if strings.HasPrefix(config, "options") {
|
||||||
|
config = strings.ReplaceAll(config, "rotate", "")
|
||||||
|
config = strings.Join(strings.Fields(config), " ")
|
||||||
|
|
||||||
|
if strings.Contains(config, "timeout:") {
|
||||||
|
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
|
||||||
|
} else {
|
||||||
|
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(config, "attempts:") {
|
||||||
|
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
|
||||||
|
} else {
|
||||||
|
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
configs[i] = config
|
||||||
|
return configs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
|
||||||
|
// and writes the file back to the original location
|
||||||
|
func removeFirstNbNameserver(filename, nameserverIP string) error {
|
||||||
|
resolvConf, err := parseResolvConfFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse backup resolv.conf: %w", err)
|
||||||
|
}
|
||||||
|
content, err := os.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read %s: %w", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP {
|
||||||
|
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
|
||||||
|
|
||||||
|
stat, err := os.Stat(filename)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stat %s: %w", filename, err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
|
||||||
|
return fmt.Errorf("write %s: %w", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
304
client/internal/dns/file_parser_linux_test.go
Normal file
304
client/internal/dns/file_parser_linux_test.go
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseResolvConf(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
input string
|
||||||
|
expectedSearch []string
|
||||||
|
expectedNS []string
|
||||||
|
expectedOther []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: `domain example.org
|
||||||
|
search example.org
|
||||||
|
nameserver 192.168.0.1
|
||||||
|
`,
|
||||||
|
expectedSearch: []string{"example.org"},
|
||||||
|
expectedNS: []string{"192.168.0.1"},
|
||||||
|
expectedOther: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||||
|
# Do not edit.
|
||||||
|
#
|
||||||
|
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||||
|
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||||
|
#
|
||||||
|
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||||
|
# all known uplink DNS servers. This file lists all configured search domains.
|
||||||
|
#
|
||||||
|
# Third party programs should typically not access this file directly, but only
|
||||||
|
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||||
|
# different way, replace this symlink by a static file or a different symlink.
|
||||||
|
#
|
||||||
|
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||||
|
# operation for /etc/resolv.conf.
|
||||||
|
|
||||||
|
nameserver 192.168.2.1
|
||||||
|
nameserver 100.81.99.197
|
||||||
|
search netbird.cloud
|
||||||
|
`,
|
||||||
|
expectedSearch: []string{"netbird.cloud"},
|
||||||
|
expectedNS: []string{"192.168.2.1", "100.81.99.197"},
|
||||||
|
expectedOther: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||||
|
# Do not edit.
|
||||||
|
#
|
||||||
|
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||||
|
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||||
|
#
|
||||||
|
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||||
|
# all known uplink DNS servers. This file lists all configured search domains.
|
||||||
|
#
|
||||||
|
# Third party programs should typically not access this file directly, but only
|
||||||
|
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||||
|
# different way, replace this symlink by a static file or a different symlink.
|
||||||
|
#
|
||||||
|
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||||
|
# operation for /etc/resolv.conf.
|
||||||
|
|
||||||
|
nameserver 192.168.2.1
|
||||||
|
nameserver 100.81.99.197
|
||||||
|
search netbird.cloud
|
||||||
|
options debug
|
||||||
|
`,
|
||||||
|
expectedSearch: []string{"netbird.cloud"},
|
||||||
|
expectedNS: []string{"192.168.2.1", "100.81.99.197"},
|
||||||
|
expectedOther: []string{"options debug"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run("test", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||||
|
err := os.WriteFile(tmpResolvConf, []byte(testCase.input), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cfg, err := parseResolvConfFile(tmpResolvConf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ok := compareLists(cfg.searchDomains, testCase.expectedSearch)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok = compareLists(cfg.nameServers, testCase.expectedNS)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok = compareLists(cfg.others, testCase.expectedOther)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("invalid parse result for others, expected: %v, got: %v", testCase.expectedOther, cfg.others)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareLists(search []string, search2 []string) bool {
|
||||||
|
if len(search) != len(search2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range search {
|
||||||
|
if v != search2[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_emptyFile(t *testing.T) {
|
||||||
|
cfg, err := parseResolvConfFile("/tmp/nothing")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error, got nil")
|
||||||
|
}
|
||||||
|
if len(cfg.others) != 0 || len(cfg.searchDomains) != 0 || len(cfg.nameServers) != 0 {
|
||||||
|
t.Errorf("expected empty config, got %v", cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_symlink(t *testing.T) {
|
||||||
|
input := `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||||
|
# Do not edit.
|
||||||
|
#
|
||||||
|
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||||
|
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||||
|
#
|
||||||
|
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||||
|
# all known uplink DNS servers. This file lists all configured search domains.
|
||||||
|
#
|
||||||
|
# Third party programs should typically not access this file directly, but only
|
||||||
|
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||||
|
# different way, replace this symlink by a static file or a different symlink.
|
||||||
|
#
|
||||||
|
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||||
|
# operation for /etc/resolv.conf.
|
||||||
|
|
||||||
|
nameserver 192.168.0.1
|
||||||
|
`
|
||||||
|
|
||||||
|
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||||
|
err := os.WriteFile(tmpResolvConf, []byte(input), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpLink := filepath.Join(t.TempDir(), "symlink")
|
||||||
|
err = os.Symlink(tmpResolvConf, tmpLink)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := parseResolvConfFile(tmpLink)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.nameServers) != 1 {
|
||||||
|
t.Errorf("unexpected resolv.conf content: %v", cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareOptionsWithTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
others []string
|
||||||
|
timeout int
|
||||||
|
attempts int
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Append new options with timeout and attempts",
|
||||||
|
others: []string{"some config"},
|
||||||
|
timeout: 2,
|
||||||
|
attempts: 2,
|
||||||
|
expected: []string{"some config", "options timeout:2 attempts:2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Modify existing options to exclude rotate and include timeout and attempts",
|
||||||
|
others: []string{"some config", "options rotate someother"},
|
||||||
|
timeout: 3,
|
||||||
|
attempts: 2,
|
||||||
|
expected: []string{"some config", "options attempts:2 timeout:3 someother"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Existing options with timeout and attempts are updated",
|
||||||
|
others: []string{"some config", "options timeout:4 attempts:3"},
|
||||||
|
timeout: 5,
|
||||||
|
attempts: 4,
|
||||||
|
expected: []string{"some config", "options timeout:5 attempts:4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Modify existing options, add missing attempts before timeout",
|
||||||
|
others: []string{"some config", "options timeout:4"},
|
||||||
|
timeout: 4,
|
||||||
|
attempts: 3,
|
||||||
|
expected: []string{"some config", "options attempts:3 timeout:4"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts)
|
||||||
|
assert.Equal(t, tc.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveFirstNbNameserver(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
ipToRemove string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Unrelated nameservers with comments and options",
|
||||||
|
content: `# This is a comment
|
||||||
|
options rotate
|
||||||
|
nameserver 1.1.1.1
|
||||||
|
# Another comment
|
||||||
|
nameserver 8.8.4.4
|
||||||
|
search example.com`,
|
||||||
|
ipToRemove: "9.9.9.9",
|
||||||
|
expected: `# This is a comment
|
||||||
|
options rotate
|
||||||
|
nameserver 1.1.1.1
|
||||||
|
# Another comment
|
||||||
|
nameserver 8.8.4.4
|
||||||
|
search example.com`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "First nameserver matches",
|
||||||
|
content: `search example.com
|
||||||
|
nameserver 9.9.9.9
|
||||||
|
# oof, a comment
|
||||||
|
nameserver 8.8.4.4
|
||||||
|
options attempts:5`,
|
||||||
|
ipToRemove: "9.9.9.9",
|
||||||
|
expected: `search example.com
|
||||||
|
# oof, a comment
|
||||||
|
nameserver 8.8.4.4
|
||||||
|
options attempts:5`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Target IP not the first nameserver",
|
||||||
|
// nolint:dupword
|
||||||
|
content: `# Comment about the first nameserver
|
||||||
|
nameserver 8.8.4.4
|
||||||
|
# Comment before our target
|
||||||
|
nameserver 9.9.9.9
|
||||||
|
options timeout:2`,
|
||||||
|
ipToRemove: "9.9.9.9",
|
||||||
|
// nolint:dupword
|
||||||
|
expected: `# Comment about the first nameserver
|
||||||
|
nameserver 8.8.4.4
|
||||||
|
# Comment before our target
|
||||||
|
nameserver 9.9.9.9
|
||||||
|
options timeout:2`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Only nameserver matches",
|
||||||
|
content: `options debug
|
||||||
|
nameserver 9.9.9.9
|
||||||
|
search localdomain`,
|
||||||
|
ipToRemove: "9.9.9.9",
|
||||||
|
expected: `options debug
|
||||||
|
nameserver 9.9.9.9
|
||||||
|
search localdomain`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
tempFile := filepath.Join(tempDir, "resolv.conf")
|
||||||
|
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = removeFirstNbNameserver(tempFile, tc.ipToRemove)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
content, err := os.ReadFile(tempFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
159
client/internal/dns/file_repair_linux.go
Normal file
159
client/internal/dns/file_repair_linux.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
eventTypes = []fsnotify.Op{
|
||||||
|
fsnotify.Create,
|
||||||
|
fsnotify.Write,
|
||||||
|
fsnotify.Remove,
|
||||||
|
fsnotify.Rename,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type repairConfFn func([]string, string, *resolvConf) error
|
||||||
|
|
||||||
|
type repair struct {
|
||||||
|
operationFile string
|
||||||
|
updateFn repairConfFn
|
||||||
|
watchDir string
|
||||||
|
|
||||||
|
inotify *fsnotify.Watcher
|
||||||
|
inotifyWg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRepair(operationFile string, updateFn repairConfFn) *repair {
|
||||||
|
targetFile := targetFile(operationFile)
|
||||||
|
return &repair{
|
||||||
|
operationFile: targetFile,
|
||||||
|
watchDir: path.Dir(targetFile),
|
||||||
|
updateFn: updateFn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
|
||||||
|
if f.inotify != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("start to watch resolv.conf: %s", f.operationFile)
|
||||||
|
inotify, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to start inotify watcher for resolv.conf: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.inotify = inotify
|
||||||
|
|
||||||
|
f.inotifyWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer f.inotifyWg.Done()
|
||||||
|
for event := range f.inotify.Events {
|
||||||
|
if !f.isEventRelevant(event) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("%s changed, check if it is broken", f.operationFile)
|
||||||
|
|
||||||
|
rConf, err := parseResolvConfFile(f.operationFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse resolv conf: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("check resolv.conf parameters: %s", rConf)
|
||||||
|
if !isNbParamsMissing(nbSearchDomains, nbNameserverIP, rConf) {
|
||||||
|
log.Tracef("resolv.conf still correct, skip the update")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Info("broken params in resolv.conf, repairing it...")
|
||||||
|
|
||||||
|
err = f.inotify.Remove(f.watchDir)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to repair resolv.conf: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = f.inotify.Add(f.watchDir)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to re-add inotify watch for resolv.conf: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = f.inotify.Add(f.watchDir)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add inotify watch for resolv.conf: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *repair) stopWatchFileChanges() {
|
||||||
|
if f.inotify == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := f.inotify.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to close resolv.conf inotify: %v", err)
|
||||||
|
}
|
||||||
|
f.inotifyWg.Wait()
|
||||||
|
f.inotify = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *repair) isEventRelevant(event fsnotify.Event) bool {
|
||||||
|
var ok bool
|
||||||
|
for _, et := range eventTypes {
|
||||||
|
if event.Has(et) {
|
||||||
|
ok = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Name == f.operationFile {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// nbParamsAreMissing checks if the resolv.conf file contains all the parameters that NetBird needs
|
||||||
|
// check the NetBird related nameserver IP at the first place
|
||||||
|
// check the NetBird related search domains in the search domains list
|
||||||
|
func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *resolvConf) bool {
|
||||||
|
if !isContains(nbSearchDomains, rConf.searchDomains) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rConf.nameServers) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if rConf.nameServers[0] != nbNameserverIP {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func targetFile(filename string) string {
|
||||||
|
target, err := filepath.EvalSymlinks(filename)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("evarl err: %s", err)
|
||||||
|
}
|
||||||
|
return target
|
||||||
|
}
|
||||||
175
client/internal/dns/file_repair_linux_test.go
Normal file
175
client/internal/dns/file_repair_linux_test.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
_ = util.InitLog("debug", "console")
|
||||||
|
code := m.Run()
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_newRepairtmp(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
resolvConfContent string
|
||||||
|
touchedConfContent string
|
||||||
|
wantChange bool
|
||||||
|
}
|
||||||
|
tests := []args{
|
||||||
|
{
|
||||||
|
resolvConfContent: `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something`,
|
||||||
|
|
||||||
|
touchedConfContent: `
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something`,
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
resolvConfContent: `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something`,
|
||||||
|
|
||||||
|
touchedConfContent: `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something somethingelse`,
|
||||||
|
wantChange: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
resolvConfContent: `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something`,
|
||||||
|
|
||||||
|
touchedConfContent: `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
searchdomain netbird.cloud something`,
|
||||||
|
wantChange: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
resolvConfContent: `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something`,
|
||||||
|
|
||||||
|
touchedConfContent: `
|
||||||
|
searchdomain something`,
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
resolvConfContent: `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something`,
|
||||||
|
|
||||||
|
touchedConfContent: `
|
||||||
|
nameserver 10.0.0.1`,
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
resolvConfContent: `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something`,
|
||||||
|
|
||||||
|
touchedConfContent: `
|
||||||
|
nameserver 8.8.8.8`,
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run("test", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
workDir := t.TempDir()
|
||||||
|
operationFile := workDir + "/resolv.conf"
|
||||||
|
err := os.WriteFile(operationFile, []byte(tt.resolvConfContent), 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var changed bool
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
updateFn := func([]string, string, *resolvConf) error {
|
||||||
|
changed = true
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r := newRepair(operationFile, updateFn)
|
||||||
|
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
||||||
|
|
||||||
|
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
r.stopWatchFileChanges()
|
||||||
|
|
||||||
|
if changed != tt.wantChange {
|
||||||
|
t.Errorf("unexpected result: want: %v, got: %v", tt.wantChange, changed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_newRepairSymlink(t *testing.T) {
|
||||||
|
resolvConfContent := `
|
||||||
|
nameserver 10.0.0.1
|
||||||
|
nameserver 8.8.8.8
|
||||||
|
searchdomain netbird.cloud something`
|
||||||
|
|
||||||
|
modifyContent := `nameserver 8.8.8.8`
|
||||||
|
|
||||||
|
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||||
|
err := os.WriteFile(tmpResolvConf, []byte(resolvConfContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpLink := filepath.Join(t.TempDir(), "symlink")
|
||||||
|
err = os.Symlink(tmpResolvConf, tmpLink)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var changed bool
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
updateFn := func([]string, string, *resolvConf) error {
|
||||||
|
changed = true
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r := newRepair(tmpLink, updateFn)
|
||||||
|
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
||||||
|
|
||||||
|
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
r.stopWatchFileChanges()
|
||||||
|
|
||||||
|
if changed != true {
|
||||||
|
t.Errorf("unexpected result: want: %v, got: %v", true, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -11,6 +12,7 @@ type hostManager interface {
|
|||||||
applyDNSConfig(config HostDNSConfig) error
|
applyDNSConfig(config HostDNSConfig) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
supportCustomPort() bool
|
supportCustomPort() bool
|
||||||
|
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type HostDNSConfig struct {
|
type HostDNSConfig struct {
|
||||||
@@ -27,9 +29,10 @@ type DomainConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockHostConfigurator struct {
|
type mockHostConfigurator struct {
|
||||||
applyDNSConfigFunc func(config HostDNSConfig) error
|
applyDNSConfigFunc func(config HostDNSConfig) error
|
||||||
restoreHostDNSFunc func() error
|
restoreHostDNSFunc func() error
|
||||||
supportCustomPortFunc func() bool
|
supportCustomPortFunc func() bool
|
||||||
|
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||||
@@ -53,11 +56,19 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
||||||
|
if m.restoreUncleanShutdownDNSFunc != nil {
|
||||||
|
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func newNoopHostMocker() hostManager {
|
func newNoopHostMocker() hostManager {
|
||||||
return &mockHostConfigurator{
|
return &mockHostConfigurator{
|
||||||
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
|
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
|
||||||
restoreHostDNSFunc: func() error { return nil },
|
restoreHostDNSFunc: func() error { return nil },
|
||||||
supportCustomPortFunc: func() bool { return true },
|
supportCustomPortFunc: func() bool { return true },
|
||||||
|
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
type androidHostManager struct {
|
type androidHostManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
func newHostManager() (hostManager, error) {
|
||||||
return &androidHostManager{}, nil
|
return &androidHostManager{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -18,3 +20,7 @@ func (a androidHostManager) restoreHostDNS() error {
|
|||||||
func (a androidHostManager) supportCustomPort() bool {
|
func (a androidHostManager) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -34,7 +36,7 @@ type systemConfigurator struct {
|
|||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(_ WGIface) (hostManager, error) {
|
func newHostManager() (hostManager, error) {
|
||||||
return &systemConfigurator{
|
return &systemConfigurator{
|
||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
@@ -50,17 +52,22 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
|
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("add dns setup for all: %w", err)
|
||||||
}
|
}
|
||||||
} else if s.primaryServiceID != "" {
|
} else if s.primaryServiceID != "" {
|
||||||
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
|
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("remote key from system config: %w", err)
|
||||||
}
|
}
|
||||||
s.primaryServiceID = ""
|
s.primaryServiceID = ""
|
||||||
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
|
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a file for unclean shutdown detection
|
||||||
|
if err := createUncleanShutdownIndicator(); err != nil {
|
||||||
|
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
@@ -85,7 +92,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
err = s.removeKeyFromSystemConfig(matchKey)
|
err = s.removeKeyFromSystemConfig(matchKey)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("add match domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
@@ -96,7 +103,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
err = s.removeKeyFromSystemConfig(searchKey)
|
err = s.removeKeyFromSystemConfig(searchKey)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -119,7 +126,11 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
|||||||
_, err := runSystemConfigCommand(wrapCommand(lines))
|
_, err := runSystemConfigCommand(wrapCommand(lines))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("got an error while cleaning the system configuration: %s", err)
|
log.Errorf("got an error while cleaning the system configuration: %s", err)
|
||||||
return err
|
return fmt.Errorf("clean system: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||||
|
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -129,7 +140,7 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
|||||||
line := buildRemoveKeyOperation(key)
|
line := buildRemoveKeyOperation(key)
|
||||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("remove key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(s.createdKeys, key)
|
delete(s.createdKeys, key)
|
||||||
@@ -140,7 +151,7 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
|||||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
|
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
|
||||||
err := s.addDNSState(key, domains, ip, port, true)
|
err := s.addDNSState(key, domains, ip, port, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("add dns state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||||
@@ -153,7 +164,7 @@ func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, po
|
|||||||
func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error {
|
func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error {
|
||||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
err := s.addDNSState(key, domains, dnsServer, port, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("add dns state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||||
@@ -178,33 +189,37 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
|||||||
|
|
||||||
_, err := runSystemConfigCommand(stdinCommands)
|
_, err := runSystemConfigCommand(stdinCommands)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while applying state for domains %s, error: %s", domains, err)
|
return fmt.Errorf("applying state for domains %s, error: %w", domains, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||||
primaryServiceKey, existingNameserver := s.getPrimaryService()
|
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
|
||||||
if primaryServiceKey == "" {
|
if err != nil || primaryServiceKey == "" {
|
||||||
return fmt.Errorf("couldn't find the primary service key")
|
return fmt.Errorf("couldn't find the primary service key: %w", err)
|
||||||
}
|
}
|
||||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
|
||||||
|
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("add dns setup: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
|
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
|
||||||
s.primaryServiceID = primaryServiceKey
|
s.primaryServiceID = primaryServiceKey
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) getPrimaryService() (string, string) {
|
func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||||
line := buildCommandLine("show", globalIPv4State, "")
|
line := buildCommandLine("show", globalIPv4State, "")
|
||||||
stdinCommands := wrapCommand(line)
|
stdinCommands := wrapCommand(line)
|
||||||
|
|
||||||
b, err := runSystemConfigCommand(stdinCommands)
|
b, err := runSystemConfigCommand(stdinCommands)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("got error while sending the command: ", err)
|
return "", "", fmt.Errorf("sending the command: %w", err)
|
||||||
return "", ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||||
primaryService := ""
|
primaryService := ""
|
||||||
router := ""
|
router := ""
|
||||||
@@ -217,7 +232,11 @@ func (s *systemConfigurator) getPrimaryService() (string, string) {
|
|||||||
router = strings.TrimSpace(strings.Split(text, ":")[1])
|
router = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return primaryService, router
|
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||||
|
return primaryService, router, fmt.Errorf("scan: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return primaryService, router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
||||||
@@ -228,7 +247,14 @@ func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, e
|
|||||||
stdinCommands := wrapCommand(addDomainCommand)
|
stdinCommands := wrapCommand(addDomainCommand)
|
||||||
_, err := runSystemConfigCommand(stdinCommands)
|
_, err := runSystemConfigCommand(stdinCommands)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while applying dns setup, error: %s", err)
|
return fmt.Errorf("applying dns setup, error: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||||
|
if err := s.restoreHostDNS(); err != nil {
|
||||||
|
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -266,7 +292,7 @@ func runSystemConfigCommand(command string) ([]byte, error) {
|
|||||||
cmd.Stdin = strings.NewReader(command)
|
cmd.Stdin = strings.NewReader(command)
|
||||||
out, err := cmd.Output()
|
out, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("got error while running system configuration command: \"%s\", error: %s", command, err)
|
return nil, fmt.Errorf("running system configuration command: \"%s\", error: %w", command, err)
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@@ -20,7 +22,7 @@ func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
|
|||||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
|
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
|
||||||
jsonData, err := json.Marshal(config)
|
jsonData, err := json.Marshal(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("marshal: %w", err)
|
||||||
}
|
}
|
||||||
jsonString := string(jsonData)
|
jsonString := string(jsonData)
|
||||||
log.Debugf("Applying DNS settings: %s", jsonString)
|
log.Debugf("Applying DNS settings: %s", jsonString)
|
||||||
@@ -35,3 +37,7 @@ func (a iosHostManager) restoreHostDNS() error {
|
|||||||
func (a iosHostManager) supportCustomPort() bool {
|
func (a iosHostManager) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,17 +4,15 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
defaultResolvConfPath = "/etc/resolv.conf"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
netbirdManager osManagerType = iota
|
netbirdManager osManagerType = iota
|
||||||
fileManager
|
fileManager
|
||||||
@@ -23,8 +21,27 @@ const (
|
|||||||
resolvConfManager
|
resolvConfManager
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
|
||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
|
func newOsManagerType(osManager string) (osManagerType, error) {
|
||||||
|
switch osManager {
|
||||||
|
case "netbird":
|
||||||
|
return fileManager, nil
|
||||||
|
case "file":
|
||||||
|
return netbirdManager, nil
|
||||||
|
case "networkManager":
|
||||||
|
return networkManager, nil
|
||||||
|
case "systemd":
|
||||||
|
return systemdManager, nil
|
||||||
|
case "resolvconf":
|
||||||
|
return resolvConfManager, nil
|
||||||
|
default:
|
||||||
|
return 0, ErrUnknownOsManagerType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t osManagerType) String() string {
|
func (t osManagerType) String() string {
|
||||||
switch t {
|
switch t {
|
||||||
case netbirdManager:
|
case netbirdManager:
|
||||||
@@ -42,13 +59,17 @@ func (t osManagerType) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("discovered mode is: %s", osManager)
|
log.Infof("System DNS manager discovered: %s", osManager)
|
||||||
|
return newHostManagerFromType(wgInterface, osManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
|
||||||
switch osManager {
|
switch osManager {
|
||||||
case networkManager:
|
case networkManager:
|
||||||
return newNetworkManagerDbusConfigurator(wgInterface)
|
return newNetworkManagerDbusConfigurator(wgInterface)
|
||||||
@@ -62,12 +83,15 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getOSDNSManagerType() (osManagerType, error) {
|
func getOSDNSManagerType() (osManagerType, error) {
|
||||||
|
|
||||||
file, err := os.Open(defaultResolvConfPath)
|
file, err := os.Open(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %s", defaultResolvConfPath, err)
|
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer func() {
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
|
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
@@ -85,7 +109,11 @@ func getOSDNSManagerType() (osManagerType, error) {
|
|||||||
return networkManager, nil
|
return networkManager, nil
|
||||||
}
|
}
|
||||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||||
return systemdManager, nil
|
if checkStub() {
|
||||||
|
return systemdManager, nil
|
||||||
|
} else {
|
||||||
|
return fileManager, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if strings.Contains(text, "resolvconf") {
|
if strings.Contains(text, "resolvconf") {
|
||||||
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||||
@@ -101,5 +129,26 @@ func getOSDNSManagerType() (osManagerType, error) {
|
|||||||
return resolvConfManager, nil
|
return resolvConfManager, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||||
|
return 0, fmt.Errorf("scan: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return fileManager, nil
|
return fileManager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||||
|
func checkStub() bool {
|
||||||
|
rConf, err := parseDefaultResolvConf()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse resolv conf: %s", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ns := range rConf.nameServers {
|
||||||
|
if ns == "127.0.0.53" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -9,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dnsPolicyConfigMatchPath = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig\\NetBird-Match"
|
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||||
dnsPolicyConfigVersionKey = "Version"
|
dnsPolicyConfigVersionKey = "Version"
|
||||||
dnsPolicyConfigVersionValue = 2
|
dnsPolicyConfigVersionValue = 2
|
||||||
dnsPolicyConfigNameKey = "Name"
|
dnsPolicyConfigNameKey = "Name"
|
||||||
@@ -19,7 +21,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
|
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
)
|
)
|
||||||
@@ -34,12 +36,16 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return newHostManagerWithGuid(guid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHostManagerWithGuid(guid string) (hostManager, error) {
|
||||||
return ®istryConfigurator{
|
return ®istryConfigurator{
|
||||||
guid: guid,
|
guid: guid,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *registryConfigurator) supportCustomPort() bool {
|
func (r *registryConfigurator) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,17 +54,22 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
err = r.addDNSSetupForAll(config.ServerIP)
|
err = r.addDNSSetupForAll(config.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("add dns setup: %w", err)
|
||||||
}
|
}
|
||||||
} else if r.routingAll {
|
} else if r.routingAll {
|
||||||
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
|
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||||
}
|
}
|
||||||
r.routingAll = false
|
r.routingAll = false
|
||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a file for unclean shutdown detection
|
||||||
|
if err := createUncleanShutdownIndicator(r.guid); err != nil {
|
||||||
|
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
@@ -80,12 +91,12 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("add dns match policy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.updateSearchDomains(searchDomains)
|
err = r.updateSearchDomains(searchDomains)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -94,7 +105,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
||||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
|
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("adding dns setup for all failed with error: %s", err)
|
return fmt.Errorf("adding dns setup for all failed with error: %w", err)
|
||||||
}
|
}
|
||||||
r.routingAll = true
|
r.routingAll = true
|
||||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
||||||
@@ -106,33 +117,33 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
|
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
|
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
|
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
|
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
|
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigVersionKey, err)
|
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
|
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigNameKey, err)
|
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
|
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigGenericDNSServersKey, err)
|
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
|
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigConfigOptionsKey, err)
|
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
|
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
|
||||||
@@ -141,18 +152,25 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) restoreHostDNS() error {
|
func (r *registryConfigurator) restoreHostDNS() error {
|
||||||
err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||||
if err != nil {
|
log.Errorf("remove registry key from dns policy config: %s", err)
|
||||||
log.Error(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey)
|
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
||||||
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||||
|
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("adding search domain failed with error: %s", err)
|
return fmt.Errorf("adding search domain failed with error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
|
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
|
||||||
@@ -163,13 +181,13 @@ func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
|||||||
func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error {
|
func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error {
|
||||||
regKey, err := r.getInterfaceRegistryKey()
|
regKey, err := r.getInterfaceRegistryKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("get interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
defer regKey.Close()
|
defer closer(regKey)
|
||||||
|
|
||||||
err = regKey.SetStringValue(key, value)
|
err = regKey.SetStringValue(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %s", key, value, err)
|
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -178,13 +196,13 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
|
|||||||
func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error {
|
func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error {
|
||||||
regKey, err := r.getInterfaceRegistryKey()
|
regKey, err := r.getInterfaceRegistryKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("get interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
defer regKey.Close()
|
defer closer(regKey)
|
||||||
|
|
||||||
err = regKey.DeleteValue(propertyKey)
|
err = regKey.DeleteValue(propertyKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("deleting registry key %s for interface failed with error: %s", propertyKey, err)
|
return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -197,20 +215,33 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
|||||||
|
|
||||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
|
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return regKey, nil
|
return regKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||||
|
if err := r.restoreHostDNS(); err != nil {
|
||||||
|
return fmt.Errorf("restoring dns via registry: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
||||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
k.Close()
|
defer closer(k)
|
||||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
|
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
|
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func closer(closer io.Closer) {
|
||||||
|
if err := closer.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
|||||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
|
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||||
fullRecord, err := dns.NewRR(record.String())
|
fullRecord, err := dns.NewRR(record.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("register record: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fullRecord.Header().Rdlength = record.Len()
|
fullRecord.Header().Rdlength = record.Len()
|
||||||
@@ -71,3 +71,5 @@ func buildRecordKey(name string, class, qType uint16) string {
|
|||||||
key := fmt.Sprintf("%s_%d_%d", name, class, qType)
|
key := fmt.Sprintf("%s_%d_%d", name, class, qType)
|
||||||
return key
|
return key
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *localResolver) probeAvailability() {}
|
||||||
|
|||||||
@@ -48,3 +48,7 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|||||||
func (m *MockServer) SearchDomains() []string {
|
func (m *MockServer) SearchDomains() []string {
|
||||||
return make([]string, 0)
|
return make([]string, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||||
|
func (m *MockServer) ProbeAvailability() {
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
@@ -41,9 +43,13 @@ const (
|
|||||||
networkManagerDbusPrimaryDNSPriority int32 = -500
|
networkManagerDbusPrimaryDNSPriority int32 = -500
|
||||||
networkManagerDbusWithMatchDomainPriority int32 = 0
|
networkManagerDbusWithMatchDomainPriority int32 = 0
|
||||||
networkManagerDbusSearchDomainOnlyPriority int32 = 50
|
networkManagerDbusSearchDomainOnlyPriority int32 = 50
|
||||||
supportedNetworkManagerVersionConstraint = ">= 1.16, < 1.28"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var supportedNetworkManagerVersionConstraints = []string{
|
||||||
|
">= 1.16, < 1.27",
|
||||||
|
">= 1.44, < 1.45",
|
||||||
|
}
|
||||||
|
|
||||||
type networkManagerDbusConfigurator struct {
|
type networkManagerDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
routingAll bool
|
routingAll bool
|
||||||
@@ -71,19 +77,19 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("get nm dbus: %w", err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
var s string
|
var s string
|
||||||
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.Name()).Store(&s)
|
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface).Store(&s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("call: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.Name())
|
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
|
||||||
|
|
||||||
return &networkManagerDbusConfigurator{
|
return &networkManagerDbusConfigurator{
|
||||||
dbusLinkObject: dbus.ObjectPath(s),
|
dbusLinkObject: dbus.ObjectPath(s),
|
||||||
@@ -97,14 +103,14 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
|||||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||||
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err)
|
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connSettings.cleanDeprecatedSettings()
|
connSettings.cleanDeprecatedSettings()
|
||||||
|
|
||||||
dnsIP, err := netip.ParseAddr(config.ServerIP)
|
dnsIP, err := netip.ParseAddr(config.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse ip address, error: %s", err)
|
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
||||||
}
|
}
|
||||||
convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice())
|
convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice())
|
||||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
|
||||||
@@ -145,23 +151,37 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
|
|||||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||||
|
|
||||||
|
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
||||||
|
// The file content itself is not important for network-manager restoration
|
||||||
|
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
|
||||||
|
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||||
err = n.reApplyConnectionSettings(connSettings, configVersion)
|
err = n.reApplyConnectionSettings(connSettings, configVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while reapplying the connection with new settings, error: %s", err)
|
return fmt.Errorf("reapplying the connection with new settings, error: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
|
func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
|
||||||
// once the interface is gone network manager cleans all config associated with it
|
// once the interface is gone network manager cleans all config associated with it
|
||||||
return n.deleteConnectionSettings()
|
if err := n.deleteConnectionSettings(); err != nil {
|
||||||
|
return fmt.Errorf("delete connection settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||||
|
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) {
|
func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
return nil, 0, fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
|
|
||||||
@@ -176,7 +196,7 @@ func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (network
|
|||||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag,
|
err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag,
|
||||||
networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion)
|
networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("got error while calling GetAppliedConnection method with context, err: %s", err)
|
return nil, 0, fmt.Errorf("calling GetAppliedConnection method with context, err: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return connSettings, configVersion, nil
|
return connSettings, configVersion, nil
|
||||||
@@ -185,7 +205,7 @@ func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (network
|
|||||||
func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error {
|
func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
return fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
|
|
||||||
@@ -195,7 +215,7 @@ func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings
|
|||||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag,
|
err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag,
|
||||||
connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store()
|
connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while calling ReApply method with context, err: %s", err)
|
return fmt.Errorf("calling ReApply method with context, err: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -204,21 +224,34 @@ func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings
|
|||||||
func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
|
func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
return fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
// this call is required to remove the device for DNS cleanup, even if it fails
|
||||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store()
|
err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while calling delete method with context, err: %s", err)
|
var dbusErr dbus.Error
|
||||||
|
if errors.As(err, &dbusErr) && dbusErr.Name == dbus.ErrMsgUnknownMethod.Name {
|
||||||
|
// interface is gone already
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("calling delete method with context, err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||||
|
if err := n.restoreHostDNS(); err != nil {
|
||||||
|
return fmt.Errorf("restoring dns via network-manager: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func isNetworkManagerSupported() bool {
|
func isNetworkManagerSupported() bool {
|
||||||
return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode()
|
return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode()
|
||||||
}
|
}
|
||||||
@@ -250,13 +283,13 @@ func isNetworkManagerSupportedMode() bool {
|
|||||||
func getNetworkManagerDNSProperty(property string, store any) error {
|
func getNetworkManagerDNSProperty(property string, store any) error {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while attempting to retrieve the network manager dns manager object, error: %s", err)
|
return fmt.Errorf("attempting to retrieve the network manager dns manager object, error: %w", err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
|
|
||||||
v, e := obj.GetProperty(property)
|
v, e := obj.GetProperty(property)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return fmt.Errorf("got an error getting property %s: %v", property, e)
|
return fmt.Errorf("getting property %s: %w", property, e)
|
||||||
}
|
}
|
||||||
|
|
||||||
return v.Store(store)
|
return v.Store(store)
|
||||||
@@ -278,15 +311,26 @@ func isNetworkManagerSupportedVersion() bool {
|
|||||||
}
|
}
|
||||||
versionValue, err := parseVersion(value.Value().(string))
|
versionValue, err := parseVersion(value.Value().(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Errorf("nm: parse version: %s", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
constraints, err := version.NewConstraint(supportedNetworkManagerVersionConstraint)
|
var supported bool
|
||||||
if err != nil {
|
for _, constraint := range supportedNetworkManagerVersionConstraints {
|
||||||
return false
|
constr, err := version.NewConstraint(constraint)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("nm: create constraint: %s", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if met := constr.Check(versionValue); met {
|
||||||
|
supported = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return constraints.Check(versionValue)
|
log.Debugf("network manager constraints [%s] met: %t", strings.Join(supportedNetworkManagerVersionConstraints, " | "), supported)
|
||||||
|
return supported
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseVersion(inputVersion string) (*version.Version, error) {
|
func parseVersion(inputVersion string) (*version.Version, error) {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -21,17 +22,17 @@ type resolvconf struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// supported "openresolv" only
|
// supported "openresolv" only
|
||||||
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
|
||||||
originalSearchDomains, nameServers, others, err := originalDNSConfigs("/etc/resolv.conf")
|
resolvConfEntries, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &resolvconf{
|
return &resolvconf{
|
||||||
ifaceName: wgInterface.Name(),
|
ifaceName: wgInterface,
|
||||||
originalSearchDomains: originalSearchDomains,
|
originalSearchDomains: resolvConfEntries.searchDomains,
|
||||||
originalNameServers: nameServers,
|
originalNameServers: resolvConfEntries.nameServers,
|
||||||
othersConfigs: others,
|
othersConfigs: resolvConfEntries.others,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,7 +45,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
if !config.RouteAll {
|
if !config.RouteAll {
|
||||||
err = r.restoreHostDNS()
|
err = r.restoreHostDNS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Errorf("restore host dns: %s", err)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
|
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
|
||||||
}
|
}
|
||||||
@@ -52,14 +53,21 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
searchDomainList := searchDomains(config)
|
searchDomainList := searchDomains(config)
|
||||||
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
|
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
|
||||||
|
|
||||||
|
options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
||||||
|
|
||||||
buf := prepareResolvConfContent(
|
buf := prepareResolvConfContent(
|
||||||
searchDomainList,
|
searchDomainList,
|
||||||
append([]string{config.ServerIP}, r.originalNameServers...),
|
append([]string{config.ServerIP}, r.originalNameServers...),
|
||||||
r.othersConfigs)
|
options)
|
||||||
|
|
||||||
|
// create a backup for unclean shutdown detection before the resolv.conf is changed
|
||||||
|
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
|
||||||
|
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
err = r.applyConfig(buf)
|
err = r.applyConfig(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("apply config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
|
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
|
||||||
@@ -67,20 +75,34 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) restoreHostDNS() error {
|
func (r *resolvconf) restoreHostDNS() error {
|
||||||
|
// openresolv only, debian resolvconf doesn't support "-f"
|
||||||
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while removing resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||||
|
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||||
|
// openresolv only, debian resolvconf doesn't support "-x"
|
||||||
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||||
cmd.Stdin = &content
|
cmd.Stdin = &content
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got an error while applying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||||
|
if err := r.restoreHostDNS(); err != nil {
|
||||||
|
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,10 +31,13 @@ func (r *responseWriter) RemoteAddr() net.Addr {
|
|||||||
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
||||||
buff, err := msg.Pack()
|
buff, err := msg.Pack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("pack: %w", err)
|
||||||
}
|
}
|
||||||
_, err = r.Write(buff)
|
|
||||||
return err
|
if _, err := r.Write(buff); err != nil {
|
||||||
|
return fmt.Errorf("write: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes a raw buffer back to the client.
|
// Write writes a raw buffer back to the client.
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,6 +34,7 @@ type Server interface {
|
|||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
OnUpdatedHostDNSServer(strings []string)
|
OnUpdatedHostDNSServer(strings []string)
|
||||||
SearchDomains() []string
|
SearchDomains() []string
|
||||||
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
type registeredHandlerMap map[string]handlerWithStop
|
||||||
@@ -58,11 +61,14 @@ type DefaultServer struct {
|
|||||||
// make sense on mobile only
|
// make sense on mobile only
|
||||||
searchDomainNotifier *notifier
|
searchDomainNotifier *notifier
|
||||||
iosDnsManager IosDnsManager
|
iosDnsManager IosDnsManager
|
||||||
|
|
||||||
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
dns.Handler
|
dns.Handler
|
||||||
stop()
|
stop()
|
||||||
|
probeAvailability()
|
||||||
}
|
}
|
||||||
|
|
||||||
type muxUpdate struct {
|
type muxUpdate struct {
|
||||||
@@ -71,7 +77,12 @@ type muxUpdate struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
func NewDefaultServer(
|
||||||
|
ctx context.Context,
|
||||||
|
wgInterface WGIface,
|
||||||
|
customAddress string,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
) (*DefaultServer, error) {
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if customAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
@@ -88,13 +99,20 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
|
|||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
|
func NewDefaultServerPermanentUpstream(
|
||||||
|
ctx context.Context,
|
||||||
|
wgInterface WGIface,
|
||||||
|
hostsDnsList []string,
|
||||||
|
config nbdns.Config,
|
||||||
|
listener listener.NetworkChangeListener,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
) *DefaultServer {
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||||
ds.permanent = true
|
ds.permanent = true
|
||||||
ds.hostsDnsList = hostsDnsList
|
ds.hostsDnsList = hostsDnsList
|
||||||
ds.addHostRootZone()
|
ds.addHostRootZone()
|
||||||
@@ -106,13 +124,18 @@ func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServerIos returns a new dns server. It optimized for ios
|
// NewDefaultServerIos returns a new dns server. It optimized for ios
|
||||||
func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer {
|
func NewDefaultServerIos(
|
||||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
ctx context.Context,
|
||||||
|
wgInterface WGIface,
|
||||||
|
iosDnsManager IosDnsManager,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
) *DefaultServer {
|
||||||
|
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||||
ds.iosDnsManager = iosDnsManager
|
ds.iosDnsManager = iosDnsManager
|
||||||
return ds
|
return ds
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -122,7 +145,8 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
}
|
}
|
||||||
|
|
||||||
return defaultServer
|
return defaultServer
|
||||||
@@ -140,12 +164,15 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
if s.permanent {
|
if s.permanent {
|
||||||
err = s.service.Listen()
|
err = s.service.Listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("service listen: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.hostManager, err = s.initialize()
|
s.hostManager, err = s.initialize()
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("initialize: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DnsIP returns the DNS resolver server IP address
|
// DnsIP returns the DNS resolver server IP address
|
||||||
@@ -223,7 +250,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := s.applyConfiguration(update); err != nil {
|
if err := s.applyConfiguration(update); err != nil {
|
||||||
return err
|
return fmt.Errorf("apply configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.updateSerial = serial
|
s.updateSerial = serial
|
||||||
@@ -248,6 +275,20 @@ func (s *DefaultServer) SearchDomains() []string {
|
|||||||
return searchDomains
|
return searchDomains
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProbeAvailability tests each upstream group's servers for availability
|
||||||
|
// and deactivates the group if no server responds
|
||||||
|
func (s *DefaultServer) ProbeAvailability() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, mux := range s.dnsMuxMap {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(mux handlerWithStop) {
|
||||||
|
defer wg.Done()
|
||||||
|
mux.probeAvailability()
|
||||||
|
}(mux)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
// is the service should be Disabled, we stop the listener or fake resolver
|
// is the service should be Disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
@@ -286,6 +327,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.updateNSGroupStates(update.NameServerGroups)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -325,7 +368,13 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
|
handler, err := newUpstreamResolver(
|
||||||
|
s.ctx,
|
||||||
|
s.wgInterface.Name(),
|
||||||
|
s.wgInterface.Address().IP,
|
||||||
|
s.wgInterface.Address().Network,
|
||||||
|
s.statusRecorder,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -378,6 +427,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return muxUpdates, nil
|
return muxUpdates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -446,14 +496,14 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
|||||||
func (s *DefaultServer) upstreamCallbacks(
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
nsGroup *nbdns.NameServerGroup,
|
nsGroup *nbdns.NameServerGroup,
|
||||||
handler dns.Handler,
|
handler dns.Handler,
|
||||||
) (deactivate func(), reactivate func()) {
|
) (deactivate func(error), reactivate func()) {
|
||||||
var removeIndex map[string]int
|
var removeIndex map[string]int
|
||||||
deactivate = func() {
|
deactivate = func(err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
l.Info("temporary deactivate nameservers group due timeout")
|
l.Info("Temporarily deactivating nameservers group due to timeout")
|
||||||
|
|
||||||
removeIndex = make(map[string]int)
|
removeIndex = make(map[string]int)
|
||||||
for _, domain := range nsGroup.Domains {
|
for _, domain := range nsGroup.Domains {
|
||||||
@@ -472,8 +522,11 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.updateNSState(nsGroup, err, false)
|
||||||
|
|
||||||
}
|
}
|
||||||
reactivate = func() {
|
reactivate = func() {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
@@ -488,20 +541,28 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
l.Debug("reactivate temporary Disabled nameserver group")
|
l.Debug("reactivate temporary disabled nameserver group")
|
||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
l.WithError(err).Error("reactivate temporary Disabled nameserver group, DNS update apply")
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.updateNSState(nsGroup, nil, true)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) addHostRootZone() {
|
func (s *DefaultServer) addHostRootZone() {
|
||||||
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
|
handler, err := newUpstreamResolver(
|
||||||
|
s.ctx,
|
||||||
|
s.wgInterface.Name(),
|
||||||
|
s.wgInterface.Address().IP,
|
||||||
|
s.wgInterface.Address().Network,
|
||||||
|
s.statusRecorder,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||||
return
|
return
|
||||||
@@ -521,7 +582,50 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
|
|
||||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
|
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
|
||||||
}
|
}
|
||||||
handler.deactivate = func() {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||||
|
var states []peer.NSGroupState
|
||||||
|
|
||||||
|
for _, group := range groups {
|
||||||
|
var servers []string
|
||||||
|
for _, ns := range group.NameServers {
|
||||||
|
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
state := peer.NSGroupState{
|
||||||
|
ID: generateGroupKey(group),
|
||||||
|
Servers: servers,
|
||||||
|
Domains: group.Domains,
|
||||||
|
// The probe will determine the state, default enabled
|
||||||
|
Enabled: true,
|
||||||
|
Error: nil,
|
||||||
|
}
|
||||||
|
states = append(states, state)
|
||||||
|
}
|
||||||
|
s.statusRecorder.UpdateDNSStates(states)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) {
|
||||||
|
states := s.statusRecorder.GetDNSStates()
|
||||||
|
id := generateGroupKey(nsGroup)
|
||||||
|
for i, state := range states {
|
||||||
|
if state.ID == id {
|
||||||
|
states[i].Enabled = enabled
|
||||||
|
states[i].Error = err
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.statusRecorder.UpdateDNSStates(states)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||||
|
var servers []string
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||||
return newHostManager(s.wgInterface)
|
return newHostManager()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||||
return newHostManager(s.wgInterface)
|
return newHostManager()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||||
return newHostManager(s.wgInterface)
|
return newHostManager(s.wgInterface.Name())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
@@ -59,6 +60,10 @@ func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) {
|
||||||
|
return iface.WGStats{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
{
|
{
|
||||||
Name: "peera.netbird.cloud",
|
Name: "peera.netbird.cloud",
|
||||||
@@ -270,7 +275,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -371,7 +376,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@@ -466,7 +471,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
@@ -537,6 +542,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
{false, "domain2", false},
|
{false, "domain2", false},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
statusRecorder: &peer.Status{},
|
||||||
}
|
}
|
||||||
|
|
||||||
var domainsUpdate string
|
var domainsUpdate string
|
||||||
@@ -559,7 +565,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
deactivate()
|
deactivate(nil)
|
||||||
expected := "domain0,domain2"
|
expected := "domain0,domain2"
|
||||||
domains := []string{}
|
domains := []string{}
|
||||||
for _, item := range server.currentConfig.Domains {
|
for _, item := range server.currentConfig.Domains {
|
||||||
@@ -597,7 +603,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|||||||
|
|
||||||
var dnsList []string
|
var dnsList []string
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@@ -621,7 +627,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@@ -713,7 +719,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@@ -744,6 +750,11 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
NSType: nbdns.UDPNameServerType,
|
NSType: nbdns.UDPNameServerType,
|
||||||
Port: 53,
|
Port: 53,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("9.9.9.9"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Domains: []string{"customdomain.com"},
|
Domains: []string{"customdomain.com"},
|
||||||
Primary: false,
|
Primary: false,
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ type serviceViaListener struct {
|
|||||||
customAddr *netip.AddrPort
|
customAddr *netip.AddrPort
|
||||||
server *dns.Server
|
server *dns.Server
|
||||||
listenIP string
|
listenIP string
|
||||||
listenPort int
|
listenPort uint16
|
||||||
listenerIsRunning bool
|
listenerIsRunning bool
|
||||||
listenerFlagLock sync.Mutex
|
listenerFlagLock sync.Mutex
|
||||||
ebpfService ebpfMgr.Manager
|
ebpfService ebpfMgr.Manager
|
||||||
@@ -63,18 +63,9 @@ func (s *serviceViaListener) Listen() error {
|
|||||||
s.listenIP, s.listenPort, err = s.evalListenAddress()
|
s.listenIP, s.listenPort, err = s.evalListenAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to eval runtime address: %s", err)
|
log.Errorf("failed to eval runtime address: %s", err)
|
||||||
return err
|
return fmt.Errorf("eval listen address: %w", err)
|
||||||
}
|
}
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
||||||
|
|
||||||
if s.shouldApplyPortFwd() {
|
|
||||||
s.ebpfService = ebpf.GetEbpfManagerInstance()
|
|
||||||
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
|
|
||||||
s.ebpfService = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
go func() {
|
go func() {
|
||||||
s.setListenerStatus(true)
|
s.setListenerStatus(true)
|
||||||
@@ -128,7 +119,7 @@ func (s *serviceViaListener) RuntimePort() int {
|
|||||||
if s.ebpfService != nil {
|
if s.ebpfService != nil {
|
||||||
return defaultPort
|
return defaultPort
|
||||||
} else {
|
} else {
|
||||||
return s.listenPort
|
return int(s.listenPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,54 +131,112 @@ func (s *serviceViaListener) setListenerStatus(running bool) {
|
|||||||
s.listenerIsRunning = running
|
s.listenerIsRunning = running
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
// evalListenAddress figure out the listen address for the DNS server
|
||||||
ips := []string{defaultIP, customIP}
|
// first check the 53 port availability on WG interface or lo, if not success
|
||||||
|
// pick a random port on WG interface for eBPF, if not success
|
||||||
|
// check the 5053 port availability on WG interface or lo without eBPF usage,
|
||||||
|
func (s *serviceViaListener) evalListenAddress() (string, uint16, error) {
|
||||||
|
if s.customAddr != nil {
|
||||||
|
return s.customAddr.Addr().String(), s.customAddr.Port(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, ok := s.testFreePort(defaultPort)
|
||||||
|
if ok {
|
||||||
|
return ip, defaultPort, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ebpfSrv, port, ok := s.tryToUseeBPF()
|
||||||
|
if ok {
|
||||||
|
s.ebpfService = ebpfSrv
|
||||||
|
return s.wgInterface.Address().IP.String(), port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, ok = s.testFreePort(customPort)
|
||||||
|
if ok {
|
||||||
|
return ip, customPort, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", 0, fmt.Errorf("failed to find a free port for DNS server")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) testFreePort(port int) (string, bool) {
|
||||||
|
var ips []string
|
||||||
if runtime.GOOS != "darwin" {
|
if runtime.GOOS != "darwin" {
|
||||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
ips = []string{s.wgInterface.Address().IP.String(), defaultIP, customIP}
|
||||||
|
} else {
|
||||||
|
ips = []string{defaultIP, customIP}
|
||||||
}
|
}
|
||||||
ports := []int{defaultPort, customPort}
|
|
||||||
for _, port := range ports {
|
for _, ip := range ips {
|
||||||
for _, ip := range ips {
|
if !s.tryToBind(ip, port) {
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
continue
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err == nil {
|
|
||||||
err = probeListener.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
|
||||||
}
|
|
||||||
return ip, port, nil
|
|
||||||
}
|
|
||||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ip, true
|
||||||
}
|
}
|
||||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) evalListenAddress() (string, int, error) {
|
func (s *serviceViaListener) tryToBind(ip string, port int) bool {
|
||||||
if s.customAddr != nil {
|
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||||
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
}
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err != nil {
|
||||||
return s.getFirstListenerAvailable()
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
}
|
|
||||||
|
|
||||||
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
|
|
||||||
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
|
|
||||||
// resolution won't work.
|
|
||||||
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
|
|
||||||
// traffic on port 53 and forward it to a local DNS server running on 5053.
|
|
||||||
func (s *serviceViaListener) shouldApplyPortFwd() bool {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.customAddr != nil {
|
err = probeListener.Close()
|
||||||
return false
|
if err != nil {
|
||||||
}
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
|
|
||||||
if s.listenPort == defaultPort {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tryToUseeBPF decides whether to apply eBPF program to capture DNS traffic on port 53.
|
||||||
|
// This is needed because on some operating systems if we start a DNS server not on a default port 53,
|
||||||
|
// the domain name resolution won't work. So, in case we are running on Linux and picked a free
|
||||||
|
// port we should fall back to the eBPF solution that will capture traffic on port 53 and forward
|
||||||
|
// it to a local DNS server running on the chosen port.
|
||||||
|
func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := s.generateFreePort() //nolint:staticcheck,unused
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to generate a free port for eBPF DNS forwarder server: %s", err)
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
ebpfSrv := ebpf.GetEbpfManagerInstance()
|
||||||
|
err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port))
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err)
|
||||||
|
return nil, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return ebpfSrv, port, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceViaListener) generateFreePort() (uint16, error) {
|
||||||
|
ok := s.tryToBind(s.wgInterface.Address().IP.String(), customPort)
|
||||||
|
if ok {
|
||||||
|
return customPort, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("0.0.0.0:0"))
|
||||||
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to bind random port for DNS: %s", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort := netip.MustParseAddrPort(probeListener.LocalAddr().String()) // might panic if address is incorrect
|
||||||
|
err = probeListener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to free up DNS port: %s", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return addrPort.Port(), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (s *serviceViaMemory) Listen() error {
|
|||||||
var err error
|
var err error
|
||||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("filter dns traffice: %w", err)
|
||||||
}
|
}
|
||||||
s.listenerIsRunning = true
|
s.listenerIsRunning = true
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -30,6 +31,8 @@ const (
|
|||||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||||
systemdDbusResolvConfModeForeign = "foreign"
|
systemdDbusResolvConfModeForeign = "foreign"
|
||||||
|
|
||||||
|
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||||
)
|
)
|
||||||
|
|
||||||
type systemdDbusConfigurator struct {
|
type systemdDbusConfigurator struct {
|
||||||
@@ -52,22 +55,22 @@ type systemdDbusLinkDomainsInput struct {
|
|||||||
MatchOnly bool
|
MatchOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
iface, err := net.InterfaceByName(wgInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("get interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("get dbus resolved dest: %w", err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
|
|
||||||
var s string
|
var s string
|
||||||
err = obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, iface.Index).Store(&s)
|
err = obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, iface.Index).Store(&s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("get dbus link method: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
|
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
|
||||||
@@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
|||||||
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||||
parsedIP, err := netip.ParseAddr(config.ServerIP)
|
parsedIP, err := netip.ParseAddr(config.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse ip address, error: %s", err)
|
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
||||||
}
|
}
|
||||||
ipAs4 := parsedIP.As4()
|
ipAs4 := parsedIP.As4()
|
||||||
defaultLinkInput := systemdDbusDNSInput{
|
defaultLinkInput := systemdDbusDNSInput{
|
||||||
@@ -93,7 +96,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
}
|
}
|
||||||
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
|
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.ServerIP, config.ServerPort, err)
|
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -121,7 +124,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||||
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting link as default dns router, failed with error: %s", err)
|
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
|
||||||
}
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: nbdns.RootZone,
|
Domain: nbdns.RootZone,
|
||||||
@@ -132,6 +135,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
||||||
|
// The file content itself is not important for systemd restoration
|
||||||
|
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil {
|
||||||
|
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||||
err = s.setDomainsForInterface(domainsInput)
|
err = s.setDomainsForInterface(domainsInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -143,7 +152,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error {
|
func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error {
|
||||||
err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput)
|
err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting domains configuration failed with error: %s", err)
|
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
||||||
}
|
}
|
||||||
return s.flushCaches()
|
return s.flushCaches()
|
||||||
}
|
}
|
||||||
@@ -153,17 +162,29 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
|||||||
if !isDbusListenerRunning(systemdResolvedDest, s.dbusLinkObject) {
|
if !isDbusListenerRunning(systemdResolvedDest, s.dbusLinkObject) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this call is required for DNS cleanup, even if it fails
|
||||||
err := s.callLinkMethod(systemdDbusRevertMethodSuffix, nil)
|
err := s.callLinkMethod(systemdDbusRevertMethodSuffix, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to revert link configuration, got error: %s", err)
|
var dbusErr dbus.Error
|
||||||
|
if errors.As(err, &dbusErr) && dbusErr.Name == dbusErrorUnknownObject {
|
||||||
|
// interface is gone already
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||||
|
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
return s.flushCaches()
|
return s.flushCaches()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) flushCaches() error {
|
func (s *systemdDbusConfigurator) flushCaches() error {
|
||||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while attempting to retrieve the object %s, err: %s", systemdDbusObjectNode, err)
|
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||||
@@ -171,7 +192,7 @@ func (s *systemdDbusConfigurator) flushCaches() error {
|
|||||||
|
|
||||||
err = obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, dbusDefaultFlag).Store()
|
err = obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, dbusDefaultFlag).Store()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while calling the FlushCaches method with context, err: %s", err)
|
return fmt.Errorf("calling the FlushCaches method with context, err: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -180,7 +201,7 @@ func (s *systemdDbusConfigurator) flushCaches() error {
|
|||||||
func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error {
|
func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error {
|
||||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, s.dbusLinkObject)
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, s.dbusLinkObject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while attempting to retrieve the object, err: %s", err)
|
return fmt.Errorf("attempting to retrieve the object, err: %w", err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
|
|
||||||
@@ -194,22 +215,29 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while calling command with context, err: %s", err)
|
return fmt.Errorf("calling command with context, err: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||||
|
if err := s.restoreHostDNS(); err != nil {
|
||||||
|
return fmt.Errorf("restoring dns via systemd: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getSystemdDbusProperty(property string, store any) error {
|
func getSystemdDbusProperty(property string, store any) error {
|
||||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("got error while attempting to retrieve the systemd dns manager object, error: %s", err)
|
return fmt.Errorf("attempting to retrieve the systemd dns manager object, error: %w", err)
|
||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
|
|
||||||
v, e := obj.GetProperty(property)
|
v, e := obj.GetProperty(property)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return fmt.Errorf("got an error getting property %s: %v", property, e)
|
return fmt.Errorf("getting property %s: %w", property, e)
|
||||||
}
|
}
|
||||||
|
|
||||||
return v.Store(store)
|
return v.Store(store)
|
||||||
|
|||||||
5
client/internal/dns/unclean_shutdown_android.go
Normal file
5
client/internal/dns/unclean_shutdown_android.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
func CheckUncleanShutdown(string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
59
client/internal/dns/unclean_shutdown_darwin.go
Normal file
59
client/internal/dns/unclean_shutdown_darwin.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns"
|
||||||
|
|
||||||
|
func CheckUncleanShutdown(string) error {
|
||||||
|
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
// no file -> clean shutdown
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("state: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
|
||||||
|
|
||||||
|
manager, err := newHostManager()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
||||||
|
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUncleanShutdownIndicator() error {
|
||||||
|
dir := filepath.Dir(fileUncleanShutdownFileLocation)
|
||||||
|
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||||
|
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
|
||||||
|
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeUncleanShutdownIndicator() error {
|
||||||
|
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
5
client/internal/dns/unclean_shutdown_ios.go
Normal file
5
client/internal/dns/unclean_shutdown_ios.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
func CheckUncleanShutdown(string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
96
client/internal/dns/unclean_shutdown_linux.go
Normal file
96
client/internal/dns/unclean_shutdown_linux.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||||
|
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CheckUncleanShutdown(wgIface string) error {
|
||||||
|
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
// no file -> clean shutdown
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("state: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation)
|
||||||
|
|
||||||
|
managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
managerFields := strings.Split(string(managerData), ",")
|
||||||
|
if len(managerFields) < 2 {
|
||||||
|
return errors.New("split manager data: insufficient number of fields")
|
||||||
|
}
|
||||||
|
osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1]
|
||||||
|
|
||||||
|
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr)
|
||||||
|
|
||||||
|
// determine os manager type, so we can invoke the respective restore action
|
||||||
|
osManagerType, err := newOsManagerType(osManagerTypeStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("detect previous host manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := newHostManagerFromType(wgIface, osManagerType)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create previous host manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil {
|
||||||
|
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error {
|
||||||
|
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
|
||||||
|
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||||
|
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := copyFile(sourcePath, fileUncleanShutdownResolvConfLocation); err != nil {
|
||||||
|
return fmt.Errorf("create %s: %w", sourcePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress)
|
||||||
|
|
||||||
|
if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec
|
||||||
|
return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeUncleanShutdownIndicator() error {
|
||||||
|
if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err)
|
||||||
|
}
|
||||||
|
if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
75
client/internal/dns/unclean_shutdown_windows.go
Normal file
75
client/internal/dns/unclean_shutdown_windows.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
netbirdProgramDataLocation = "Netbird"
|
||||||
|
fileUncleanShutdownFile = "unclean_shutdown_dns.txt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CheckUncleanShutdown(string) error {
|
||||||
|
file := getUncleanShutdownFile()
|
||||||
|
|
||||||
|
if _, err := os.Stat(file); err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
// no file -> clean shutdown
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("state: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file)
|
||||||
|
|
||||||
|
guid, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read %s: %w", file, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := newHostManagerWithGuid(string(guid))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
||||||
|
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUncleanShutdownIndicator(guid string) error {
|
||||||
|
file := getUncleanShutdownFile()
|
||||||
|
|
||||||
|
dir := filepath.Dir(file)
|
||||||
|
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||||
|
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(file, []byte(guid), 0600); err != nil {
|
||||||
|
return fmt.Errorf("create %s: %w", file, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeUncleanShutdownIndicator() error {
|
||||||
|
file := getUncleanShutdownFile()
|
||||||
|
|
||||||
|
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return fmt.Errorf("remove %s: %w", file, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUncleanShutdownFile() string {
|
||||||
|
return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile)
|
||||||
|
}
|
||||||
@@ -11,18 +11,24 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
failsTillDeact = int32(5)
|
failsTillDeact = int32(5)
|
||||||
reactivatePeriod = 30 * time.Second
|
reactivatePeriod = 30 * time.Second
|
||||||
upstreamTimeout = 15 * time.Second
|
upstreamTimeout = 15 * time.Second
|
||||||
|
probeTimeout = 2 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const testRecord = "."
|
||||||
|
|
||||||
type upstreamClient interface {
|
type upstreamClient interface {
|
||||||
exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpstreamResolver interface {
|
type UpstreamResolver interface {
|
||||||
@@ -42,12 +48,13 @@ type upstreamResolverBase struct {
|
|||||||
reactivatePeriod time.Duration
|
reactivatePeriod time.Duration
|
||||||
upstreamTimeout time.Duration
|
upstreamTimeout time.Duration
|
||||||
|
|
||||||
deactivate func()
|
deactivate func(error)
|
||||||
reactivate func()
|
reactivate func()
|
||||||
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
|
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
|
||||||
ctx, cancel := context.WithCancel(parentCTX)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
return &upstreamResolverBase{
|
return &upstreamResolverBase{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -55,6 +62,7 @@ func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
|
|||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: upstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,7 +73,10 @@ func (u *upstreamResolverBase) stop() {
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
defer u.checkUpstreamFails()
|
var err error
|
||||||
|
defer func() {
|
||||||
|
u.checkUpstreamFails(err)
|
||||||
|
}()
|
||||||
|
|
||||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||||
|
|
||||||
@@ -76,11 +87,17 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
|
var rm *dns.Msg
|
||||||
|
var t time.Duration
|
||||||
|
|
||||||
rm, t, err := u.upstreamClient.exchange(upstream, r)
|
func() {
|
||||||
|
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||||
|
defer cancel()
|
||||||
|
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
|
||||||
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == context.DeadlineExceeded || isTimeout(err) {
|
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
log.WithError(err).WithField("upstream", upstream).
|
||||||
Warn("got an error while connecting to upstream")
|
Warn("got an error while connecting to upstream")
|
||||||
continue
|
continue
|
||||||
@@ -122,7 +139,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
// If fails count is greater that failsTillDeact, upstream resolving
|
// If fails count is greater that failsTillDeact, upstream resolving
|
||||||
// will be disabled for reactivatePeriod, after that time period fails counter
|
// will be disabled for reactivatePeriod, after that time period fails counter
|
||||||
// will be reset and upstream will be reactivated.
|
// will be reset and upstream will be reactivated.
|
||||||
func (u *upstreamResolverBase) checkUpstreamFails() {
|
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
||||||
u.mutex.Lock()
|
u.mutex.Lock()
|
||||||
defer u.mutex.Unlock()
|
defer u.mutex.Unlock()
|
||||||
|
|
||||||
@@ -134,13 +151,52 @@ func (u *upstreamResolverBase) checkUpstreamFails() {
|
|||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
// todo test the deactivation logic, it seems to affect the client
|
}
|
||||||
if runtime.GOOS != "ios" {
|
|
||||||
log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod)
|
u.disable(err)
|
||||||
u.deactivate()
|
}
|
||||||
u.disabled = true
|
|
||||||
go u.waitUntilResponse()
|
// probeAvailability tests all upstream servers simultaneously and
|
||||||
}
|
// disables the resolver if none work
|
||||||
|
func (u *upstreamResolverBase) probeAvailability() {
|
||||||
|
u.mutex.Lock()
|
||||||
|
defer u.mutex.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-u.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
var success bool
|
||||||
|
var mu sync.Mutex
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
var errors *multierror.Error
|
||||||
|
for _, upstream := range u.upstreamServers {
|
||||||
|
upstream := upstream
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := u.testNameserver(upstream)
|
||||||
|
if err != nil {
|
||||||
|
errors = multierror.Append(errors, err)
|
||||||
|
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
success = true
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// didn't find a working upstream server, let's disable and try later
|
||||||
|
if !success {
|
||||||
|
u.disable(errors.ErrorOrNil())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,8 +212,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
Clock: backoff.SystemClock,
|
Clock: backoff.SystemClock,
|
||||||
}
|
}
|
||||||
|
|
||||||
r := new(dns.Msg).SetQuestion("netbird.io.", dns.TypeA)
|
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
@@ -165,17 +219,17 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
_, _, err = u.upstreamClient.exchange(upstream, r)
|
if err := u.testNameserver(upstream); err != nil {
|
||||||
|
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||||
if err == nil {
|
} else {
|
||||||
|
// at least one upstream server is available, stop probing
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("checking connectivity with upstreams %s failed with error: %s. Retrying in %s", err, u.upstreamServers, exponentialBackOff.NextBackOff())
|
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff())
|
||||||
return fmt.Errorf("got an error from upstream check call")
|
return fmt.Errorf("upstream check call error")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := backoff.Retry(operation, exponentialBackOff)
|
err := backoff.Retry(operation, exponentialBackOff)
|
||||||
@@ -200,3 +254,27 @@ func isTimeout(err error) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) disable(err error) {
|
||||||
|
if u.disabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo test the deactivation logic, it seems to affect the client
|
||||||
|
if runtime.GOOS != "ios" {
|
||||||
|
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||||
|
u.deactivate(err)
|
||||||
|
u.disabled = true
|
||||||
|
go u.waitUntilResponse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) testNameserver(server string) error {
|
||||||
|
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||||
|
|
||||||
|
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolverIOS struct {
|
type upstreamResolverIOS struct {
|
||||||
@@ -20,8 +22,14 @@ type upstreamResolverIOS struct {
|
|||||||
iIndex int
|
iIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) {
|
func newUpstreamResolver(
|
||||||
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
|
ctx context.Context,
|
||||||
|
interfaceName string,
|
||||||
|
ip net.IP,
|
||||||
|
net *net.IPNet,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
) (*upstreamResolverIOS, error) {
|
||||||
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||||
|
|
||||||
index, err := getInterfaceIndex(interfaceName)
|
index, err := getInterfaceIndex(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -40,30 +48,38 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net
|
|||||||
return ios, nil
|
return ios, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
client := &dns.Client{}
|
client := &dns.Client{}
|
||||||
upstreamHost, _, err := net.SplitHostPort(upstream)
|
upstreamHost, _, err := net.SplitHostPort(upstream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while parsing upstream host: %s", err)
|
log.Errorf("error while parsing upstream host: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timeout := upstreamTimeout
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
timeout = time.Until(deadline)
|
||||||
|
}
|
||||||
|
client.DialTimeout = timeout
|
||||||
|
|
||||||
upstreamIP := net.ParseIP(upstreamHost)
|
upstreamIP := net.ParseIP(upstreamHost)
|
||||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
||||||
log.Debugf("using private client to query upstream: %s", upstream)
|
log.Debugf("using private client to query upstream: %s", upstream)
|
||||||
client = u.getClientPrivate()
|
client = u.getClientPrivate(timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||||
return client.Exchange(r, upstream)
|
return client.Exchange(r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||||
// This method is needed for iOS
|
// This method is needed for iOS
|
||||||
func (u *upstreamResolverIOS) getClientPrivate() *dns.Client {
|
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
LocalAddr: &net.UDPAddr{
|
LocalAddr: &net.UDPAddr{
|
||||||
IP: u.lIP,
|
IP: u.lIP,
|
||||||
Port: 0, // Let the OS pick a free port
|
Port: 0, // Let the OS pick a free port
|
||||||
},
|
},
|
||||||
Timeout: upstreamTimeout,
|
Timeout: dialTimeout,
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
var operr error
|
var operr error
|
||||||
fn := func(s uintptr) {
|
fn := func(s uintptr) {
|
||||||
|
|||||||
@@ -8,14 +8,22 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolverNonIOS struct {
|
type upstreamResolverNonIOS struct {
|
||||||
*upstreamResolverBase
|
*upstreamResolverBase
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) {
|
func newUpstreamResolver(
|
||||||
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
|
ctx context.Context,
|
||||||
|
_ string,
|
||||||
|
_ net.IP,
|
||||||
|
_ *net.IPNet,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
) (*upstreamResolverNonIOS, error) {
|
||||||
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||||
nonIOS := &upstreamResolverNonIOS{
|
nonIOS := &upstreamResolverNonIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
}
|
}
|
||||||
@@ -23,10 +31,7 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net
|
|||||||
return nonIOS, nil
|
return nonIOS, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverNonIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolverNonIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
upstreamExchangeClient := &dns.Client{}
|
upstreamExchangeClient := &dns.Client{}
|
||||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||||
rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
|
||||||
cancel()
|
|
||||||
return rm, t, err
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{})
|
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil)
|
||||||
resolver.upstreamServers = testCase.InputServers
|
resolver.upstreamServers = testCase.InputServers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
@@ -105,8 +105,8 @@ type mockUpstreamResolver struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeContext mock implementation of ExchangeContext from upstreamResolver
|
// exchange mock implementation of exchange from upstreamResolver
|
||||||
func (c mockUpstreamResolver) exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
|
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||||
return c.r, c.rtt, c.err
|
return c.r, c.rtt, c.err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
failed := false
|
failed := false
|
||||||
resolver.deactivate = func() {
|
resolver.deactivate = func(error) {
|
||||||
failed = true
|
failed = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,4 +11,5 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() iface.PacketFilter
|
GetFilter() iface.PacketFilter
|
||||||
GetDevice() *iface.DeviceWrapper
|
GetDevice() *iface.DeviceWrapper
|
||||||
|
GetStats(peerKey string) (iface.WGStats, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,5 +9,6 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() iface.PacketFilter
|
GetFilter() iface.PacketFilter
|
||||||
GetDevice() *iface.DeviceWrapper
|
GetDevice() *iface.DeviceWrapper
|
||||||
|
GetStats(peerKey string) (iface.WGStats, error)
|
||||||
GetInterfaceGUIDString() (string, error)
|
GetInterfaceGUIDString() (string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
||||||
log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort)
|
log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort)
|
||||||
tf.lock.Lock()
|
tf.lock.Lock()
|
||||||
defer tf.lock.Unlock()
|
defer tf.lock.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
# Debug
|
# DNS forwarder
|
||||||
|
|
||||||
|
The agent attach the XDP program to the lo device. We can not use fake address in eBPF because the
|
||||||
|
traffic does not appear in the eBPF program. The program capture the traffic on wg_ip:53 and
|
||||||
|
overwrite in it the destination port to 5053.
|
||||||
|
|
||||||
|
# Debug
|
||||||
|
|
||||||
The CONFIG_BPF_EVENTS kernel module is required for bpf_printk.
|
The CONFIG_BPF_EVENTS kernel module is required for bpf_printk.
|
||||||
Apply this code to use bpf_printk
|
Apply this code to use bpf_printk
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||||
@@ -78,7 +79,10 @@ type EngineConfig struct {
|
|||||||
|
|
||||||
CustomDNSAddress string
|
CustomDNSAddress string
|
||||||
|
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
|
RosenpassPermissive bool
|
||||||
|
|
||||||
|
ServerSSHAllowed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -89,6 +93,10 @@ type Engine struct {
|
|||||||
mgmClient mgm.Client
|
mgmClient mgm.Client
|
||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerConns map[string]*peer.Conn
|
peerConns map[string]*peer.Conn
|
||||||
|
|
||||||
|
beforePeerHook peer.BeforeAddPeerHookFunc
|
||||||
|
afterPeerHook peer.AfterRemovePeerHookFunc
|
||||||
|
|
||||||
// rpManager is a Rosenpass manager
|
// rpManager is a Rosenpass manager
|
||||||
rpManager *rosenpass.Manager
|
rpManager *rosenpass.Manager
|
||||||
|
|
||||||
@@ -125,6 +133,11 @@ type Engine struct {
|
|||||||
acl acl.Manager
|
acl acl.Manager
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
|
||||||
|
mgmProbe *Probe
|
||||||
|
signalProbe *Probe
|
||||||
|
relayProbe *Probe
|
||||||
|
wgProbe *Probe
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -135,11 +148,43 @@ type Peer struct {
|
|||||||
|
|
||||||
// NewEngine creates a new Connection Engine
|
// NewEngine creates a new Connection Engine
|
||||||
func NewEngine(
|
func NewEngine(
|
||||||
ctx context.Context, cancel context.CancelFunc,
|
ctx context.Context,
|
||||||
signalClient signal.Client, mgmClient mgm.Client,
|
cancel context.CancelFunc,
|
||||||
config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
|
signalClient signal.Client,
|
||||||
|
mgmClient mgm.Client,
|
||||||
|
config *EngineConfig,
|
||||||
|
mobileDep MobileDependency,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
|
return NewEngineWithProbes(
|
||||||
|
ctx,
|
||||||
|
cancel,
|
||||||
|
signalClient,
|
||||||
|
mgmClient,
|
||||||
|
config,
|
||||||
|
mobileDep,
|
||||||
|
statusRecorder,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEngineWithProbes creates a new Connection Engine with probes attached
|
||||||
|
func NewEngineWithProbes(
|
||||||
|
ctx context.Context,
|
||||||
|
cancel context.CancelFunc,
|
||||||
|
signalClient signal.Client,
|
||||||
|
mgmClient mgm.Client,
|
||||||
|
config *EngineConfig,
|
||||||
|
mobileDep MobileDependency,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
mgmProbe *Probe,
|
||||||
|
signalProbe *Probe,
|
||||||
|
relayProbe *Probe,
|
||||||
|
wgProbe *Probe,
|
||||||
|
) *Engine {
|
||||||
return &Engine{
|
return &Engine{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
@@ -155,6 +200,10 @@ func NewEngine(
|
|||||||
sshServerFunc: nbssh.DefaultSSHServer,
|
sshServerFunc: nbssh.DefaultSSHServer,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
||||||
|
mgmProbe: mgmProbe,
|
||||||
|
signalProbe: signalProbe,
|
||||||
|
relayProbe: relayProbe,
|
||||||
|
wgProbe: wgProbe,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,38 +234,51 @@ func (e *Engine) Start() error {
|
|||||||
|
|
||||||
wgIface, err := e.newWgIface()
|
wgIface, err := e.newWgIface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error())
|
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
||||||
return err
|
return fmt.Errorf("new wg interface: %w", err)
|
||||||
}
|
}
|
||||||
e.wgInterface = wgIface
|
e.wgInterface = wgIface
|
||||||
|
|
||||||
if e.config.RosenpassEnabled {
|
if e.config.RosenpassEnabled {
|
||||||
log.Infof("rosenpass is enabled")
|
log.Infof("rosenpass is enabled")
|
||||||
|
if e.config.RosenpassPermissive {
|
||||||
|
log.Infof("running rosenpass in permissive mode")
|
||||||
|
} else {
|
||||||
|
log.Infof("running rosenpass in strict mode")
|
||||||
|
}
|
||||||
e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName)
|
e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
err := e.rpManager.Run()
|
err := e.rpManager.Run()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return err
|
return fmt.Errorf("create dns server: %w", err)
|
||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||||
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
|
} else {
|
||||||
|
e.beforePeerHook = beforePeerHook
|
||||||
|
e.afterPeerHook = afterPeerHook
|
||||||
|
}
|
||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
|
|
||||||
err = e.wgInterfaceCreate()
|
err = e.wgInterfaceCreate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||||
e.close()
|
e.close()
|
||||||
return err
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
||||||
@@ -228,7 +290,7 @@ func (e *Engine) Start() error {
|
|||||||
err = e.routeManager.EnableServerRouter(e.firewall)
|
err = e.routeManager.EnableServerRouter(e.firewall)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return err
|
return fmt.Errorf("enable server router: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,7 +298,7 @@ func (e *Engine) Start() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||||
e.close()
|
e.close()
|
||||||
return err
|
return fmt.Errorf("up wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
@@ -246,11 +308,12 @@ func (e *Engine) Start() error {
|
|||||||
err = e.dnsServer.Initialize()
|
err = e.dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return err
|
return fmt.Errorf("initialize dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.receiveSignalEvents()
|
e.receiveSignalEvents()
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
|
e.receiveProbeEvents()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -439,44 +502,52 @@ func isNil(server nbssh.Server) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||||
if sshConf.GetSshEnabled() {
|
|
||||||
if runtime.GOOS == "windows" {
|
if !e.config.ServerSSHAllowed {
|
||||||
log.Warnf("running SSH server on Windows is not supported")
|
log.Warnf("running SSH server is not permitted")
|
||||||
return nil
|
return nil
|
||||||
}
|
} else {
|
||||||
// start SSH server if it wasn't running
|
|
||||||
if isNil(e.sshServer) {
|
if sshConf.GetSshEnabled() {
|
||||||
// nil sshServer means it has not yet been started
|
if runtime.GOOS == "windows" {
|
||||||
var err error
|
log.Warnf("running SSH server on Windows is not supported")
|
||||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
|
return nil
|
||||||
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
go func() {
|
// start SSH server if it wasn't running
|
||||||
// blocking
|
if isNil(e.sshServer) {
|
||||||
err = e.sshServer.Start()
|
// nil sshServer means it has not yet been started
|
||||||
|
var err error
|
||||||
|
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
|
||||||
|
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// will throw error when we stop it even if it is a graceful stop
|
return err
|
||||||
log.Debugf("stopped SSH server with error %v", err)
|
|
||||||
}
|
}
|
||||||
e.syncMsgMux.Lock()
|
go func() {
|
||||||
defer e.syncMsgMux.Unlock()
|
// blocking
|
||||||
e.sshServer = nil
|
err = e.sshServer.Start()
|
||||||
log.Infof("stopped SSH server")
|
if err != nil {
|
||||||
}()
|
// will throw error when we stop it even if it is a graceful stop
|
||||||
} else {
|
log.Debugf("stopped SSH server with error %v", err)
|
||||||
log.Debugf("SSH server is already running")
|
}
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
e.sshServer = nil
|
||||||
|
log.Infof("stopped SSH server")
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
log.Debugf("SSH server is already running")
|
||||||
|
}
|
||||||
|
} else if !isNil(e.sshServer) {
|
||||||
|
// Disable SSH server request, so stop it if it was running
|
||||||
|
err := e.sshServer.Stop()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to stop SSH server %v", err)
|
||||||
|
}
|
||||||
|
e.sshServer = nil
|
||||||
}
|
}
|
||||||
} else if !isNil(e.sshServer) {
|
return nil
|
||||||
// Disable SSH server request, so stop it if it was running
|
|
||||||
err := e.sshServer.Stop()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to stop SSH server %v", err)
|
|
||||||
}
|
|
||||||
e.sshServer = nil
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
@@ -512,9 +583,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
func (e *Engine) receiveManagementEvents() {
|
func (e *Engine) receiveManagementEvents() {
|
||||||
go func() {
|
go func() {
|
||||||
err := e.mgmClient.Sync(func(update *mgmProto.SyncResponse) error {
|
err := e.mgmClient.Sync(e.handleSync)
|
||||||
return e.handleSync(update)
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
@@ -644,8 +713,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
if e.acl != nil {
|
if e.acl != nil {
|
||||||
e.acl.ApplyFiltering(networkMap)
|
e.acl.ApplyFiltering(networkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
|
|
||||||
|
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||||
|
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||||
|
e.dnsServer.ProbeAvailability()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -720,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
|||||||
FQDN: offlinePeer.GetFqdn(),
|
FQDN: offlinePeer.GetFqdn(),
|
||||||
ConnStatus: peer.StatusDisconnected,
|
ConnStatus: peer.StatusDisconnected,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||||
@@ -743,10 +818,15 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
if _, ok := e.peerConns[peerKey]; !ok {
|
if _, ok := e.peerConns[peerKey]; !ok {
|
||||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("create peer connection: %w", err)
|
||||||
}
|
}
|
||||||
e.peerConns[peerKey] = conn
|
e.peerConns[peerKey] = conn
|
||||||
|
|
||||||
|
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||||
|
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||||
|
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||||
|
}
|
||||||
|
|
||||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||||
@@ -815,7 +895,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
PreSharedKey: e.config.PreSharedKey,
|
PreSharedKey: e.config.PreSharedKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.config.RosenpassEnabled {
|
if e.config.RosenpassEnabled && !e.config.RosenpassPermissive {
|
||||||
lk := []byte(e.config.WgPrivateKey.PublicKey().String())
|
lk := []byte(e.config.WgPrivateKey.PublicKey().String())
|
||||||
rk := []byte(wgConfig.RemoteKey)
|
rk := []byte(wgConfig.RemoteKey)
|
||||||
var keyInput []byte
|
var keyInput []byte
|
||||||
@@ -1035,6 +1115,15 @@ func (e *Engine) close() {
|
|||||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||||
|
if e.dnsServer != nil {
|
||||||
|
e.dnsServer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.routeManager != nil {
|
||||||
|
e.routeManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
if e.wgInterface != nil {
|
if e.wgInterface != nil {
|
||||||
if err := e.wgInterface.Close(); err != nil {
|
if err := e.wgInterface.Close(); err != nil {
|
||||||
@@ -1049,14 +1138,6 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.routeManager != nil {
|
|
||||||
e.routeManager.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.dnsServer != nil {
|
|
||||||
e.dnsServer.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
err := e.firewall.Reset()
|
err := e.firewall.Reset()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1126,14 +1207,21 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
dnsServer := dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener)
|
dnsServer := dns.NewDefaultServerPermanentUpstream(
|
||||||
|
e.ctx,
|
||||||
|
e.wgInterface,
|
||||||
|
e.mobileDep.HostDNSAddresses,
|
||||||
|
*dnsConfig,
|
||||||
|
e.mobileDep.NetworkChangeListener,
|
||||||
|
e.statusRecorder,
|
||||||
|
)
|
||||||
go e.mobileDep.DnsReadyListener.OnReady()
|
go e.mobileDep.DnsReadyListener.OnReady()
|
||||||
return routes, dnsServer, nil
|
return routes, dnsServer, nil
|
||||||
case "ios":
|
case "ios":
|
||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
||||||
return nil, dnsServer, nil
|
return nil, dnsServer, nil
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -1175,3 +1263,69 @@ func (e *Engine) getRosenpassAddr() string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) receiveProbeEvents() {
|
||||||
|
if e.signalProbe != nil {
|
||||||
|
go e.signalProbe.Receive(e.ctx, func() bool {
|
||||||
|
healthy := e.signal.IsHealthy()
|
||||||
|
log.Debugf("received signal probe request, healthy: %t", healthy)
|
||||||
|
return healthy
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.mgmProbe != nil {
|
||||||
|
go e.mgmProbe.Receive(e.ctx, func() bool {
|
||||||
|
healthy := e.mgmClient.IsHealthy()
|
||||||
|
log.Debugf("received management probe request, healthy: %t", healthy)
|
||||||
|
return healthy
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.relayProbe != nil {
|
||||||
|
go e.relayProbe.Receive(e.ctx, func() bool {
|
||||||
|
healthy := true
|
||||||
|
|
||||||
|
results := append(e.probeSTUNs(), e.probeTURNs()...)
|
||||||
|
e.statusRecorder.UpdateRelayStates(results)
|
||||||
|
|
||||||
|
// A single failed server will result in a "failed" probe
|
||||||
|
for _, res := range results {
|
||||||
|
if res.Err != nil {
|
||||||
|
healthy = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("received relay probe request, healthy: %t", healthy)
|
||||||
|
return healthy
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.wgProbe != nil {
|
||||||
|
go e.wgProbe.Receive(e.ctx, func() bool {
|
||||||
|
log.Debug("received wg probe request")
|
||||||
|
|
||||||
|
for _, peer := range e.peerConns {
|
||||||
|
key := peer.GetKey()
|
||||||
|
wgStats, err := peer.GetConf().WgConfig.WgInterface.GetStats(key)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
||||||
|
}
|
||||||
|
// wgStats could be zero value, in which case we just reset the stats
|
||||||
|
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||||
|
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) probeSTUNs() []relay.ProbeResult {
|
||||||
|
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, e.STUNs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||||
|
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
@@ -70,10 +71,11 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||||
WgIfaceName: "utun101",
|
WgIfaceName: "utun101",
|
||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
|
ServerSSHAllowed: true,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -1049,8 +1051,8 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||||
eventStore, false)
|
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,12 +20,15 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface/bind"
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
iceKeepAliveDefault = 4 * time.Second
|
iceKeepAliveDefault = 4 * time.Second
|
||||||
iceDisconnectedTimeoutDefault = 6 * time.Second
|
iceDisconnectedTimeoutDefault = 6 * time.Second
|
||||||
|
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
||||||
|
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
||||||
|
|
||||||
defaultWgKeepAlive = 25 * time.Second
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
)
|
)
|
||||||
@@ -98,6 +101,9 @@ type IceCredentials struct {
|
|||||||
Pwd string
|
Pwd string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
|
||||||
|
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -130,8 +136,16 @@ type Conn struct {
|
|||||||
remoteModeCh chan ModeMessage
|
remoteModeCh chan ModeMessage
|
||||||
meta meta
|
meta meta
|
||||||
|
|
||||||
adapter iface.TunAdapter
|
adapter iface.TunAdapter
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
|
sentExtraSrflx bool
|
||||||
|
|
||||||
|
remoteEndpoint *net.UDPAddr
|
||||||
|
remoteConn *ice.Conn
|
||||||
|
|
||||||
|
connID nbnet.ConnectionID
|
||||||
|
beforeAddPeerHooks []BeforeAddPeerHookFunc
|
||||||
|
afterRemovePeerHooks []AfterRemovePeerHookFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// meta holds meta information about a connection
|
// meta holds meta information about a connection
|
||||||
@@ -192,20 +206,22 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
|
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
|
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||||
|
|
||||||
agentConfig := &ice.AgentConfig{
|
agentConfig := &ice.AgentConfig{
|
||||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||||
Urls: conn.config.StunTurn,
|
Urls: conn.config.StunTurn,
|
||||||
CandidateTypes: conn.candidateTypes(),
|
CandidateTypes: conn.candidateTypes(),
|
||||||
FailedTimeout: &failedTimeout,
|
FailedTimeout: &failedTimeout,
|
||||||
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
||||||
UDPMux: conn.config.UDPMux,
|
UDPMux: conn.config.UDPMux,
|
||||||
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
||||||
NAT1To1IPs: conn.config.NATExternalIPs,
|
NAT1To1IPs: conn.config.NATExternalIPs,
|
||||||
Net: transportNet,
|
Net: transportNet,
|
||||||
DisconnectedTimeout: &iceDisconnectedTimeout,
|
DisconnectedTimeout: &iceDisconnectedTimeout,
|
||||||
KeepaliveInterval: &iceKeepAlive,
|
KeepaliveInterval: &iceKeepAlive,
|
||||||
|
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.config.DisableIPv6Discovery {
|
if conn.config.DisableIPv6Discovery {
|
||||||
@@ -213,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.agent, err = ice.NewAgent(agentConfig)
|
conn.agent, err = ice.NewAgent(agentConfig)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -233,6 +248,17 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = conn.agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
|
||||||
|
err := conn.statusRecorder.UpdateLatency(conn.config.Key, p.Latency())
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to update latency for peer %s: %s", conn.config.Key, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed setting binding response callback: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -258,6 +284,7 @@ func (conn *Conn) Open() error {
|
|||||||
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -317,6 +344,7 @@ func (conn *Conn) Open() error {
|
|||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -347,6 +375,9 @@ func (conn *Conn) Open() error {
|
|||||||
if remoteOfferAnswer.WgListenPort != 0 {
|
if remoteOfferAnswer.WgListenPort != 0 {
|
||||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.remoteConn = remoteConn
|
||||||
|
|
||||||
// the ice connection has been established successfully so we are ready to start the proxy
|
// the ice connection has been established successfully so we are ready to start the proxy
|
||||||
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
|
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
|
||||||
remoteOfferAnswer.RosenpassAddr)
|
remoteOfferAnswer.RosenpassAddr)
|
||||||
@@ -371,6 +402,14 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
|||||||
return candidate.Type() == ice.CandidateTypeRelay
|
return candidate.Type() == ice.CandidateTypeRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
|
||||||
|
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
|
||||||
|
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
|
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
@@ -396,6 +435,15 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
}
|
}
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||||
|
conn.remoteEndpoint = endpointUdpAddr
|
||||||
|
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||||
|
|
||||||
|
conn.connID = nbnet.GenerateConnID()
|
||||||
|
for _, hook := range conn.beforeAddPeerHooks {
|
||||||
|
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
|
||||||
|
log.Errorf("Before add peer hook failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -406,14 +454,22 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.status = StatusConnected
|
conn.status = StatusConnected
|
||||||
|
rosenpassEnabled := false
|
||||||
|
if remoteRosenpassPubKey != nil {
|
||||||
|
rosenpassEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
LocalIceCandidateType: pair.Local.Type().String(),
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
Direct: !isRelayCandidate(pair.Local),
|
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||||
|
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||||
|
Direct: !isRelayCandidate(pair.Local),
|
||||||
|
RosenpassEnabled: rosenpassEnabled,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||||
peerState.Relayed = true
|
peerState.Relayed = true
|
||||||
@@ -462,6 +518,8 @@ func (conn *Conn) cleanup() error {
|
|||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
|
conn.sentExtraSrflx = false
|
||||||
|
|
||||||
var err1, err2, err3 error
|
var err1, err2, err3 error
|
||||||
if conn.agent != nil {
|
if conn.agent != nil {
|
||||||
err1 = conn.agent.Close()
|
err1 = conn.agent.Close()
|
||||||
@@ -478,6 +536,15 @@ func (conn *Conn) cleanup() error {
|
|||||||
// todo: is it problem if we try to remove a peer what is never existed?
|
// todo: is it problem if we try to remove a peer what is never existed?
|
||||||
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
|
||||||
|
if conn.connID != "" {
|
||||||
|
for _, hook := range conn.afterRemovePeerHooks {
|
||||||
|
if err := hook(conn.connID); err != nil {
|
||||||
|
log.Errorf("After remove peer hook failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.connID = ""
|
||||||
|
|
||||||
if conn.notifyDisconnected != nil {
|
if conn.notifyDisconnected != nil {
|
||||||
conn.notifyDisconnected()
|
conn.notifyDisconnected()
|
||||||
conn.notifyDisconnected = nil
|
conn.notifyDisconnected = nil
|
||||||
@@ -493,6 +560,7 @@ func (conn *Conn) cleanup() error {
|
|||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: conn.status,
|
ConnStatus: conn.status,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -500,6 +568,9 @@ func (conn *Conn) cleanup() error {
|
|||||||
// todo rethink status updates
|
// todo rethink status updates
|
||||||
log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err)
|
log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err)
|
||||||
}
|
}
|
||||||
|
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil {
|
||||||
|
log.Debugf("failed to reset wireguard stats for peer %s: %s", conn.config.Key, err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
log.Debugf("cleaned up connection to peer %s", conn.config.Key)
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
@@ -552,6 +623,30 @@ func (conn *Conn) onICECandidate(candidate ice.Candidate) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
|
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||||
|
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||||
|
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||||
|
relatedAdd := candidate.RelatedAddress()
|
||||||
|
extraSrflx, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||||
|
Network: candidate.NetworkType().String(),
|
||||||
|
Address: candidate.Address(),
|
||||||
|
Port: relatedAdd.Port,
|
||||||
|
Component: candidate.Component(),
|
||||||
|
RelAddr: relatedAdd.Address,
|
||||||
|
RelPort: relatedAdd.Port,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = conn.signalCandidate(extraSrflx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.sentExtraSrflx = true
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,9 +10,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
|
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
|
||||||
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
|
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
|
||||||
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
|
envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC"
|
||||||
|
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
|
||||||
)
|
)
|
||||||
|
|
||||||
func iceKeepAlive() time.Duration {
|
func iceKeepAlive() time.Duration {
|
||||||
@@ -21,7 +22,7 @@ func iceKeepAlive() time.Duration {
|
|||||||
return iceKeepAliveDefault
|
return iceKeepAliveDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv)
|
log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv)
|
||||||
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
|
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
|
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
|
||||||
@@ -37,7 +38,7 @@ func iceDisconnectedTimeout() time.Duration {
|
|||||||
return iceDisconnectedTimeoutDefault
|
return iceDisconnectedTimeoutDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
|
log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
|
||||||
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
|
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
|
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
|
||||||
@@ -47,6 +48,22 @@ func iceDisconnectedTimeout() time.Duration {
|
|||||||
return time.Duration(disconnectedTimeoutSec) * time.Second
|
return time.Duration(disconnectedTimeoutSec) * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func iceRelayAcceptanceMinWait() time.Duration {
|
||||||
|
iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec)
|
||||||
|
if iceRelayAcceptanceMinWaitEnv == "" {
|
||||||
|
return iceRelayAcceptanceMinWaitDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv)
|
||||||
|
disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault)
|
||||||
|
return iceRelayAcceptanceMinWaitDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Duration(disconnectedTimeoutSec) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
func hasICEForceRelayConn() bool {
|
func hasICEForceRelayConn() bool {
|
||||||
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
|
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
|
||||||
return strings.ToLower(disconnectedTimeoutEnv) == "true"
|
return strings.ToLower(disconnectedTimeoutEnv) == "true"
|
||||||
|
|||||||
@@ -4,19 +4,65 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
// State contains the latest state of a peer
|
// State contains the latest state of a peer
|
||||||
type State struct {
|
type State struct {
|
||||||
IP string
|
Mux *sync.RWMutex
|
||||||
PubKey string
|
IP string
|
||||||
FQDN string
|
PubKey string
|
||||||
ConnStatus ConnStatus
|
FQDN string
|
||||||
ConnStatusUpdate time.Time
|
ConnStatus ConnStatus
|
||||||
Relayed bool
|
ConnStatusUpdate time.Time
|
||||||
Direct bool
|
Relayed bool
|
||||||
LocalIceCandidateType string
|
Direct bool
|
||||||
RemoteIceCandidateType string
|
LocalIceCandidateType string
|
||||||
|
RemoteIceCandidateType string
|
||||||
|
LocalIceCandidateEndpoint string
|
||||||
|
RemoteIceCandidateEndpoint string
|
||||||
|
LastWireguardHandshake time.Time
|
||||||
|
BytesTx int64
|
||||||
|
BytesRx int64
|
||||||
|
Latency time.Duration
|
||||||
|
RosenpassEnabled bool
|
||||||
|
routes map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRoute add a single route to routes map
|
||||||
|
func (s *State) AddRoute(network string) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
if s.routes == nil {
|
||||||
|
s.routes = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
s.routes[network] = struct{}{}
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRoutes set state routes
|
||||||
|
func (s *State) SetRoutes(routes map[string]struct{}) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
s.routes = routes
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRoute removes a route from the network amp
|
||||||
|
func (s *State) DeleteRoute(network string) {
|
||||||
|
s.Mux.Lock()
|
||||||
|
delete(s.routes, network)
|
||||||
|
s.Mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutes return routes map
|
||||||
|
func (s *State) GetRoutes() map[string]struct{} {
|
||||||
|
s.Mux.RLock()
|
||||||
|
defer s.Mux.RUnlock()
|
||||||
|
return s.routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalPeerState contains the latest state of the local peer
|
// LocalPeerState contains the latest state of the local peer
|
||||||
@@ -25,18 +71,37 @@ type LocalPeerState struct {
|
|||||||
PubKey string
|
PubKey string
|
||||||
KernelInterface bool
|
KernelInterface bool
|
||||||
FQDN string
|
FQDN string
|
||||||
|
Routes map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignalState contains the latest state of a signal connection
|
// SignalState contains the latest state of a signal connection
|
||||||
type SignalState struct {
|
type SignalState struct {
|
||||||
URL string
|
URL string
|
||||||
Connected bool
|
Connected bool
|
||||||
|
Error error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagementState contains the latest state of a management connection
|
// ManagementState contains the latest state of a management connection
|
||||||
type ManagementState struct {
|
type ManagementState struct {
|
||||||
URL string
|
URL string
|
||||||
Connected bool
|
Connected bool
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RosenpassState contains the latest state of the Rosenpass configuration
|
||||||
|
type RosenpassState struct {
|
||||||
|
Enabled bool
|
||||||
|
Permissive bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NSGroupState represents the status of a DNS server group, including associated domains,
|
||||||
|
// whether it's enabled, and the last error message encountered during probing.
|
||||||
|
type NSGroupState struct {
|
||||||
|
ID string
|
||||||
|
Servers []string
|
||||||
|
Domains []string
|
||||||
|
Enabled bool
|
||||||
|
Error error
|
||||||
}
|
}
|
||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
// FullStatus contains the full state held by the Status instance
|
||||||
@@ -45,20 +110,29 @@ type FullStatus struct {
|
|||||||
ManagementState ManagementState
|
ManagementState ManagementState
|
||||||
SignalState SignalState
|
SignalState SignalState
|
||||||
LocalPeerState LocalPeerState
|
LocalPeerState LocalPeerState
|
||||||
|
RosenpassState RosenpassState
|
||||||
|
Relays []relay.ProbeResult
|
||||||
|
NSGroupStates []NSGroupState
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status holds a state of peers, signal and management connections
|
// Status holds a state of peers, signal, management connections and relays
|
||||||
type Status struct {
|
type Status struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
peers map[string]State
|
peers map[string]State
|
||||||
changeNotify map[string]chan struct{}
|
changeNotify map[string]chan struct{}
|
||||||
signalState bool
|
signalState bool
|
||||||
managementState bool
|
signalError error
|
||||||
localPeer LocalPeerState
|
managementState bool
|
||||||
offlinePeers []State
|
managementError error
|
||||||
mgmAddress string
|
relayStates []relay.ProbeResult
|
||||||
signalAddress string
|
localPeer LocalPeerState
|
||||||
notifier *notifier
|
offlinePeers []State
|
||||||
|
mgmAddress string
|
||||||
|
signalAddress string
|
||||||
|
notifier *notifier
|
||||||
|
rosenpassEnabled bool
|
||||||
|
rosenpassPermissive bool
|
||||||
|
nsGroupStates []NSGroupState
|
||||||
|
|
||||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||||
@@ -101,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
|||||||
PubKey: peerPubKey,
|
PubKey: peerPubKey,
|
||||||
ConnStatus: StatusDisconnected,
|
ConnStatus: StatusDisconnected,
|
||||||
FQDN: fqdn,
|
FQDN: fqdn,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
d.peerListChangedForNotification = true
|
d.peerListChangedForNotification = true
|
||||||
return nil
|
return nil
|
||||||
@@ -147,6 +222,10 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
peerState.IP = receivedState.IP
|
peerState.IP = receivedState.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if receivedState.GetRoutes() != nil {
|
||||||
|
peerState.SetRoutes(receivedState.GetRoutes())
|
||||||
|
}
|
||||||
|
|
||||||
skipNotification := shouldSkipNotify(receivedState, peerState)
|
skipNotification := shouldSkipNotify(receivedState, peerState)
|
||||||
|
|
||||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||||
@@ -156,6 +235,9 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
peerState.Relayed = receivedState.Relayed
|
peerState.Relayed = receivedState.Relayed
|
||||||
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
|
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
|
||||||
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
|
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
|
||||||
|
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
|
||||||
|
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
|
||||||
|
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
d.peers[receivedState.PubKey] = peerState
|
||||||
@@ -174,6 +256,25 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
|
||||||
|
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState.LastWireguardHandshake = wgStats.LastHandshake
|
||||||
|
peerState.BytesRx = wgStats.RxBytes
|
||||||
|
peerState.BytesTx = wgStats.TxBytes
|
||||||
|
|
||||||
|
d.peers[pubKey] = peerState
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func shouldSkipNotify(received, curr State) bool {
|
func shouldSkipNotify(received, curr State) bool {
|
||||||
switch {
|
switch {
|
||||||
case received.ConnStatus == StatusConnecting:
|
case received.ConnStatus == StatusConnecting:
|
||||||
@@ -229,6 +330,13 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
|||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLocalPeerState returns the local peer state
|
||||||
|
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
return d.localPeer
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateLocalPeerState updates local peer status
|
// UpdateLocalPeerState updates local peer status
|
||||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
@@ -248,12 +356,13 @@ func (d *Status) CleanLocalPeerState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||||
func (d *Status) MarkManagementDisconnected() {
|
func (d *Status) MarkManagementDisconnected(err error) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
defer d.onConnectionChanged()
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.managementState = false
|
d.managementState = false
|
||||||
|
d.managementError = err
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkManagementConnected sets ManagementState to connected
|
// MarkManagementConnected sets ManagementState to connected
|
||||||
@@ -263,6 +372,7 @@ func (d *Status) MarkManagementConnected() {
|
|||||||
defer d.onConnectionChanged()
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.managementState = true
|
d.managementState = true
|
||||||
|
d.managementError = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSignalAddress update the address of the signal server
|
// UpdateSignalAddress update the address of the signal server
|
||||||
@@ -279,13 +389,22 @@ func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
|||||||
d.mgmAddress = mgmAddress
|
d.mgmAddress = mgmAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateRosenpass update the Rosenpass configuration
|
||||||
|
func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.rosenpassPermissive = rosenpassPermissive
|
||||||
|
d.rosenpassEnabled = rosenpassEnabled
|
||||||
|
}
|
||||||
|
|
||||||
// MarkSignalDisconnected sets SignalState to disconnected
|
// MarkSignalDisconnected sets SignalState to disconnected
|
||||||
func (d *Status) MarkSignalDisconnected() {
|
func (d *Status) MarkSignalDisconnected(err error) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
defer d.onConnectionChanged()
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.signalState = false
|
d.signalState = false
|
||||||
|
d.signalError = err
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkSignalConnected sets SignalState to connected
|
// MarkSignalConnected sets SignalState to connected
|
||||||
@@ -295,6 +414,83 @@ func (d *Status) MarkSignalConnected() {
|
|||||||
defer d.onConnectionChanged()
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.signalState = true
|
d.signalState = true
|
||||||
|
d.signalError = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.relayStates = relayResults
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.nsGroupStates = dnsStates
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) GetRosenpassState() RosenpassState {
|
||||||
|
return RosenpassState{
|
||||||
|
d.rosenpassEnabled,
|
||||||
|
d.rosenpassPermissive,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) GetManagementState() ManagementState {
|
||||||
|
return ManagementState{
|
||||||
|
d.mgmAddress,
|
||||||
|
d.managementState,
|
||||||
|
d.managementError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||||
|
if latency <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
peerState, ok := d.peers[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
peerState.Latency = latency
|
||||||
|
d.peers[pubKey] = peerState
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLoginRequired determines if a peer's login has expired.
|
||||||
|
func (d *Status) IsLoginRequired() bool {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
// if peer is connected to the management then login is not expired
|
||||||
|
if d.managementState {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s, ok := gstatus.FromError(d.managementError)
|
||||||
|
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) GetSignalState() SignalState {
|
||||||
|
return SignalState{
|
||||||
|
d.signalAddress,
|
||||||
|
d.signalState,
|
||||||
|
d.signalError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||||
|
return d.relayStates
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) GetDNSStates() []NSGroupState {
|
||||||
|
return d.nsGroupStates
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFullStatus gets full status
|
// GetFullStatus gets full status
|
||||||
@@ -303,15 +499,12 @@ func (d *Status) GetFullStatus() FullStatus {
|
|||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
fullStatus := FullStatus{
|
fullStatus := FullStatus{
|
||||||
ManagementState: ManagementState{
|
ManagementState: d.GetManagementState(),
|
||||||
d.mgmAddress,
|
SignalState: d.GetSignalState(),
|
||||||
d.managementState,
|
LocalPeerState: d.localPeer,
|
||||||
},
|
Relays: d.GetRelayStates(),
|
||||||
SignalState: SignalState{
|
RosenpassState: d.GetRosenpassState(),
|
||||||
d.signalAddress,
|
NSGroupStates: d.GetDNSStates(),
|
||||||
d.signalState,
|
|
||||||
},
|
|
||||||
LocalPeerState: d.localPeer,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, status := range d.peers {
|
for _, status := range d.peers {
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@@ -41,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -61,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -79,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -103,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
|
|||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: key,
|
PubKey: key,
|
||||||
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
status.peers[key] = peerState
|
status.peers[key] = peerState
|
||||||
@@ -152,9 +158,10 @@ func TestUpdateSignalState(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
connected bool
|
connected bool
|
||||||
want bool
|
want bool
|
||||||
|
err error
|
||||||
}{
|
}{
|
||||||
{"should mark as connected", true, true},
|
{"should mark as connected", true, true, nil},
|
||||||
{"should mark as disconnected", false, false},
|
{"should mark as disconnected", false, false, errors.New("test")},
|
||||||
}
|
}
|
||||||
|
|
||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
@@ -165,9 +172,10 @@ func TestUpdateSignalState(t *testing.T) {
|
|||||||
if test.connected {
|
if test.connected {
|
||||||
status.MarkSignalConnected()
|
status.MarkSignalConnected()
|
||||||
} else {
|
} else {
|
||||||
status.MarkSignalDisconnected()
|
status.MarkSignalDisconnected(test.err)
|
||||||
}
|
}
|
||||||
assert.Equal(t, test.want, status.signalState, "signal status should be equal")
|
assert.Equal(t, test.want, status.signalState, "signal status should be equal")
|
||||||
|
assert.Equal(t, test.err, status.signalError)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -178,9 +186,10 @@ func TestUpdateManagementState(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
connected bool
|
connected bool
|
||||||
want bool
|
want bool
|
||||||
|
err error
|
||||||
}{
|
}{
|
||||||
{"should mark as connected", true, true},
|
{"should mark as connected", true, true, nil},
|
||||||
{"should mark as disconnected", false, false},
|
{"should mark as disconnected", false, false, errors.New("test")},
|
||||||
}
|
}
|
||||||
|
|
||||||
status := NewRecorder(url)
|
status := NewRecorder(url)
|
||||||
@@ -190,9 +199,10 @@ func TestUpdateManagementState(t *testing.T) {
|
|||||||
if test.connected {
|
if test.connected {
|
||||||
status.MarkManagementConnected()
|
status.MarkManagementConnected()
|
||||||
} else {
|
} else {
|
||||||
status.MarkManagementDisconnected()
|
status.MarkManagementDisconnected(test.err)
|
||||||
}
|
}
|
||||||
assert.Equal(t, test.want, status.managementState, "signalState status should be equal")
|
assert.Equal(t, test.want, status.managementState, "signalState status should be equal")
|
||||||
|
assert.Equal(t, test.err, status.managementError)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
51
client/internal/probe.go
Normal file
51
client/internal/probe.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Probe allows to run on-demand callbacks from different code locations.
|
||||||
|
// Pass the probe to a receiving and a sending end. The receiving end starts listening
|
||||||
|
// to requests with Receive and executes a callback when the sending end requests it
|
||||||
|
// by calling Probe.
|
||||||
|
type Probe struct {
|
||||||
|
request chan struct{}
|
||||||
|
result chan bool
|
||||||
|
ready bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProbe returns a new initialized probe.
|
||||||
|
func NewProbe() *Probe {
|
||||||
|
return &Probe{
|
||||||
|
request: make(chan struct{}),
|
||||||
|
result: make(chan bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Probe requests the callback to be run and returns a bool indicating success.
|
||||||
|
// It always returns true as long as the receiver is not ready.
|
||||||
|
func (p *Probe) Probe() bool {
|
||||||
|
if !p.ready {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
p.request <- struct{}{}
|
||||||
|
return <-p.result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive starts listening for probe requests. On such a request it runs the supplied
|
||||||
|
// callback func which must return a bool indicating success.
|
||||||
|
// Blocks until the passed context is cancelled.
|
||||||
|
func (p *Probe) Receive(ctx context.Context, callback func() bool) {
|
||||||
|
p.ready = true
|
||||||
|
defer func() {
|
||||||
|
p.ready = false
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-p.request:
|
||||||
|
p.result <- callback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
187
client/internal/relay/relay.go
Normal file
187
client/internal/relay/relay.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/stun/v2"
|
||||||
|
"github.com/pion/turn/v3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProbeResult holds the info about the result of a relay probe request
|
||||||
|
type ProbeResult struct {
|
||||||
|
URI *stun.URI
|
||||||
|
Err error
|
||||||
|
Addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProbeSTUN tries binding to the given STUN uri and acquiring an address
|
||||||
|
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||||
|
defer func() {
|
||||||
|
if probeErr != nil {
|
||||||
|
log.Debugf("stun probe error from %s: %s", uri, probeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
net, err := stdnet.NewNet(nil)
|
||||||
|
if err != nil {
|
||||||
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := stun.DialURI(uri, &stun.DialConfig{
|
||||||
|
Net: net,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
probeErr = fmt.Errorf("dial: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := client.Close(); err != nil && probeErr == nil {
|
||||||
|
probeErr = fmt.Errorf("close: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
if err = client.Start(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
||||||
|
if res.Error != nil {
|
||||||
|
probeErr = fmt.Errorf("request: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var xorAddr stun.XORMappedAddress
|
||||||
|
if getErr := xorAddr.GetFrom(res.Message); getErr != nil {
|
||||||
|
probeErr = fmt.Errorf("get xor addr: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("stun probe received address from %s: %s", uri, xorAddr)
|
||||||
|
addr = xorAddr.String()
|
||||||
|
|
||||||
|
done <- struct{}{}
|
||||||
|
}); err != nil {
|
||||||
|
probeErr = fmt.Errorf("client: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
probeErr = fmt.Errorf("stun request: %w", ctx.Err())
|
||||||
|
return
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProbeTURN tries allocating a session from the given TURN URI
|
||||||
|
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
|
||||||
|
defer func() {
|
||||||
|
if probeErr != nil {
|
||||||
|
log.Debugf("turn probe error from %s: %s", uri, probeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
turnServerAddr := fmt.Sprintf("%s:%d", uri.Host, uri.Port)
|
||||||
|
|
||||||
|
var conn net.PacketConn
|
||||||
|
switch uri.Proto {
|
||||||
|
case stun.ProtoTypeUDP:
|
||||||
|
var err error
|
||||||
|
conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "")
|
||||||
|
if err != nil {
|
||||||
|
probeErr = fmt.Errorf("listen: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case stun.ProtoTypeTCP:
|
||||||
|
tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr)
|
||||||
|
if err != nil {
|
||||||
|
probeErr = fmt.Errorf("dial: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn = turn.NewSTUNConn(tcpConn)
|
||||||
|
default:
|
||||||
|
probeErr = fmt.Errorf("conn: unknown proto: %s", uri.Proto)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil && probeErr == nil {
|
||||||
|
probeErr = fmt.Errorf("conn close: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
net, err := stdnet.NewNet(nil)
|
||||||
|
if err != nil {
|
||||||
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg := &turn.ClientConfig{
|
||||||
|
STUNServerAddr: turnServerAddr,
|
||||||
|
TURNServerAddr: turnServerAddr,
|
||||||
|
Conn: conn,
|
||||||
|
Username: uri.Username,
|
||||||
|
Password: uri.Password,
|
||||||
|
Net: net,
|
||||||
|
}
|
||||||
|
client, err := turn.NewClient(cfg)
|
||||||
|
if err != nil {
|
||||||
|
probeErr = fmt.Errorf("create client: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
if err := client.Listen(); err != nil {
|
||||||
|
probeErr = fmt.Errorf("client listen: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayConn, err := client.Allocate()
|
||||||
|
if err != nil {
|
||||||
|
probeErr = fmt.Errorf("allocate: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := relayConn.Close(); err != nil && probeErr == nil {
|
||||||
|
probeErr = fmt.Errorf("close relay conn: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Debugf("turn probe relay address from %s: %s", uri, relayConn.LocalAddr())
|
||||||
|
|
||||||
|
return relayConn.LocalAddr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProbeAll probes all given servers asynchronously and returns the results
|
||||||
|
func ProbeAll(
|
||||||
|
ctx context.Context,
|
||||||
|
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
|
||||||
|
relays []*stun.URI,
|
||||||
|
) []ProbeResult {
|
||||||
|
results := make([]ProbeResult, len(relays))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i, uri := range relays {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(res *ProbeResult, stunURI *stun.URI) {
|
||||||
|
defer wg.Done()
|
||||||
|
res.URI = stunURI
|
||||||
|
res.Addr, res.Err = fn(ctx, stunURI)
|
||||||
|
}(&results[i], uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -18,6 +19,7 @@ type routerPeerStatus struct {
|
|||||||
connected bool
|
connected bool
|
||||||
relayed bool
|
relayed bool
|
||||||
direct bool
|
direct bool
|
||||||
|
latency time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type routesUpdate struct {
|
type routesUpdate struct {
|
||||||
@@ -41,6 +43,7 @@ type clientNetwork struct {
|
|||||||
|
|
||||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
@@ -67,14 +70,29 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
||||||
relayed: peerStatus.Relayed,
|
relayed: peerStatus.Relayed,
|
||||||
direct: peerStatus.Direct,
|
direct: peerStatus.Direct,
|
||||||
|
latency: peerStatus.Latency,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return routePeerStatuses
|
return routePeerStatuses
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getBestRouteFromStatuses determines the most optimal route from the available routes
|
||||||
|
// within a clientNetwork, taking into account peer connection status, route metrics, and
|
||||||
|
// preference for non-relayed and direct connections.
|
||||||
|
//
|
||||||
|
// It follows these prioritization rules:
|
||||||
|
// * Connected peers: Only routes with connected peers are considered.
|
||||||
|
// * Metric: Routes with lower metrics (better) are prioritized.
|
||||||
|
// * Non-relayed: Routes without relays are preferred.
|
||||||
|
// * Direct connections: Routes with direct peer connections are favored.
|
||||||
|
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||||
|
// * Latency: Routes with lower latency are prioritized.
|
||||||
|
//
|
||||||
|
// It returns the ID of the selected optimal route.
|
||||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
||||||
chosen := ""
|
chosen := ""
|
||||||
chosenScore := 0
|
chosenScore := float64(0)
|
||||||
|
currScore := float64(0)
|
||||||
|
|
||||||
currID := ""
|
currID := ""
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
@@ -82,7 +100,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
tempScore := 0
|
tempScore := float64(0)
|
||||||
peerStatus, found := routePeerStatuses[r.ID]
|
peerStatus, found := routePeerStatuses[r.ID]
|
||||||
if !found || !peerStatus.connected {
|
if !found || !peerStatus.connected {
|
||||||
continue
|
continue
|
||||||
@@ -90,9 +108,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
|
|
||||||
if r.Metric < route.MaxMetric {
|
if r.Metric < route.MaxMetric {
|
||||||
metricDiff := route.MaxMetric - r.Metric
|
metricDiff := route.MaxMetric - r.Metric
|
||||||
tempScore = metricDiff * 10
|
tempScore = float64(metricDiff) * 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
||||||
|
latency := time.Second
|
||||||
|
if peerStatus.latency != 0 {
|
||||||
|
latency = peerStatus.latency
|
||||||
|
} else {
|
||||||
|
log.Warnf("peer %s has 0 latency", r.Peer)
|
||||||
|
}
|
||||||
|
tempScore += 1 - latency.Seconds()
|
||||||
|
|
||||||
if !peerStatus.relayed {
|
if !peerStatus.relayed {
|
||||||
tempScore++
|
tempScore++
|
||||||
}
|
}
|
||||||
@@ -101,7 +128,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
tempScore++
|
tempScore++
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
|
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
|
||||||
chosen = r.ID
|
chosen = r.ID
|
||||||
chosenScore = tempScore
|
chosenScore = tempScore
|
||||||
}
|
}
|
||||||
@@ -110,18 +137,26 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
|||||||
chosen = r.ID
|
chosen = r.ID
|
||||||
chosenScore = tempScore
|
chosenScore = tempScore
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.ID == currID {
|
||||||
|
currScore = tempScore
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if chosen == "" {
|
switch {
|
||||||
|
case chosen == "":
|
||||||
var peers []string
|
var peers []string
|
||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
peers = append(peers, r.Peer)
|
peers = append(peers, r.Peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
||||||
|
case chosen != currID:
|
||||||
} else if chosen != currID {
|
if currScore != 0 && currScore < chosenScore+0.1 {
|
||||||
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
|
return currID
|
||||||
|
} else {
|
||||||
|
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return chosen
|
return chosen
|
||||||
@@ -158,15 +193,21 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
|||||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("get peer state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state.DeleteRoute(c.network.String())
|
||||||
|
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||||
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if state.ConnStatus != peer.StatusConnected {
|
if state.ConnStatus != peer.StatusConnected {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
|
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
|
||||||
c.network, c.chosenRoute.Peer, err)
|
c.network, c.chosenRoute.Peer, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -174,30 +215,26 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
|||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
|
|
||||||
if err != nil {
|
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||||
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
return fmt.Errorf("remove route: %v", err)
|
||||||
c.network, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||||
|
|
||||||
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||||
|
|
||||||
|
// If no route is chosen, remove the route from the peer and system
|
||||||
if chosen == "" {
|
if chosen == "" {
|
||||||
err = c.removeRouteFromPeerAndSystem()
|
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("remove route from peer and system: %v", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.chosenRoute = nil
|
c.chosenRoute = nil
|
||||||
@@ -205,6 +242,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the chosen route is the same as the current route, do nothing
|
||||||
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
||||||
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
||||||
return nil
|
return nil
|
||||||
@@ -212,21 +250,31 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
// If a previous route exists, remove it from the peer
|
||||||
if err != nil {
|
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||||
return err
|
return fmt.Errorf("remove route from peer: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
|
// otherwise add the route to the system
|
||||||
if err != nil {
|
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil {
|
||||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.chosenRoute = c.routes[chosen]
|
c.chosenRoute = c.routes[chosen]
|
||||||
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
|
|
||||||
|
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Errorf("Failed to get peer state: %v", err)
|
||||||
|
} else {
|
||||||
|
state.AddRoute(c.network.String())
|
||||||
|
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||||
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
|
||||||
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
||||||
c.network, c.chosenRoute.Peer, err)
|
c.network, c.chosenRoute.Peer, err)
|
||||||
}
|
}
|
||||||
@@ -267,21 +315,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
log.Debugf("stopping watcher for network %s", c.network)
|
log.Debugf("stopping watcher for network %s", c.network)
|
||||||
err := c.removeRouteFromPeerAndSystem()
|
err := c.removeRouteFromPeerAndSystem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-c.peerStateUpdate:
|
case <-c.peerStateUpdate:
|
||||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
|
||||||
}
|
}
|
||||||
case update := <-c.routeUpdate:
|
case update := <-c.routeUpdate:
|
||||||
if update.updateSerial < c.updateSerial {
|
if update.updateSerial < c.updateSerial {
|
||||||
log.Warnf("received a routes update with smaller serial number, ignoring it")
|
log.Warnf("Received a routes update with smaller serial number, ignoring it")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("received a new client network route update for %s", c.network)
|
log.Debugf("Received a new client network route update for %s", c.network)
|
||||||
|
|
||||||
c.handleUpdate(update)
|
c.handleUpdate(update)
|
||||||
|
|
||||||
@@ -289,7 +337,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
|
|
||||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.startPeersStatusChangeWatcher()
|
c.startPeersStatusChangeWatcher()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package routemanager
|
|||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
statuses map[string]routerPeerStatus
|
statuses map[string]routerPeerStatus
|
||||||
expectedRouteID string
|
expectedRouteID string
|
||||||
currentRoute *route.Route
|
currentRoute string
|
||||||
existingRoutes map[string]*route.Route
|
existingRoutes map[string]*route.Route
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer1",
|
Peer: "peer1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer1",
|
Peer: "peer1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer1",
|
Peer: "peer1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer1",
|
Peer: "peer1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "",
|
expectedRouteID: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer2",
|
Peer: "peer2",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer2",
|
Peer: "peer2",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
Peer: "peer2",
|
Peer: "peer2",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
currentRoute: nil,
|
currentRoute: "",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "multiple connected peers with different latencies",
|
||||||
|
statuses: map[string]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
latency: 300 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[string]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "",
|
||||||
|
expectedRouteID: "route2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should ignore routes with latency 0",
|
||||||
|
statuses: map[string]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[string]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "",
|
||||||
|
expectedRouteID: "route2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "current route with similar score and similar but slightly worse latency should not change",
|
||||||
|
statuses: map[string]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 12 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[string]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "route1",
|
||||||
|
expectedRouteID: "route1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "current chosen route doesn't exist anymore",
|
||||||
|
statuses: map[string]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 20 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
direct: true,
|
||||||
|
latency: 10 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[string]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "routeDoesntExistAnymore",
|
||||||
|
expectedRouteID: "route2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
currentRoute := &route.Route{
|
||||||
|
ID: "routeDoesntExistAnymore",
|
||||||
|
}
|
||||||
|
if tc.currentRoute != "" {
|
||||||
|
currentRoute = tc.existingRoutes[tc.currentRoute]
|
||||||
|
}
|
||||||
|
|
||||||
// create new clientNetwork
|
// create new clientNetwork
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
network: netip.MustParsePrefix("192.168.0.0/24"),
|
network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||||
routes: tc.existingRoutes,
|
routes: tc.existingRoutes,
|
||||||
chosenRoute: tc.currentRoute,
|
chosenRoute: currentRoute,
|
||||||
}
|
}
|
||||||
|
|
||||||
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -12,11 +16,18 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
|
|
||||||
|
// nolint:unused
|
||||||
|
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
|
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
@@ -56,9 +67,31 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init sets up the routing
|
||||||
|
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
if nbnet.CustomRoutingDisabled() {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cleanupRouting(); err != nil {
|
||||||
|
log.Warnf("Failed cleaning up routing: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgmtAddress := m.statusRecorder.GetManagementState().URL
|
||||||
|
signalAddress := m.statusRecorder.GetSignalState().URL
|
||||||
|
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
||||||
|
|
||||||
|
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||||
|
}
|
||||||
|
log.Info("Routing setup complete")
|
||||||
|
return beforePeerHook, afterPeerHook, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall)
|
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -71,10 +104,19 @@ func (m *DefaultManager) Stop() {
|
|||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
m.serverRouter.cleanUp()
|
m.serverRouter.cleanUp()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !nbnet.CustomRoutingDisabled() {
|
||||||
|
if err := cleanupRouting(); err != nil {
|
||||||
|
log.Errorf("Error cleaning up routing: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Info("Routing cleanup complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.ctx = nil
|
m.ctx = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
@@ -92,7 +134,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("update routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,11 +199,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
|
|||||||
for _, newRoute := range newRoutes {
|
for _, newRoute := range newRoutes {
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
networkID := route.GetHAUniqueID(newRoute)
|
||||||
if !ownNetworkIDs[networkID] {
|
if !ownNetworkIDs[networkID] {
|
||||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
if !isPrefixSupported(newRoute.Network) {
|
||||||
// we skip this route management
|
|
||||||
if newRoute.Network.Bits() < minRangeBits {
|
|
||||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route",
|
|
||||||
version.NetbirdVersion(), newRoute.Network)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||||
@@ -179,3 +217,40 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
|||||||
}
|
}
|
||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||||
|
if !nbnet.CustomRoutingDisabled() {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux", "windows", "darwin":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||||
|
// we skip this prefix management
|
||||||
|
if prefix.Bits() <= minRangeBits {
|
||||||
|
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
||||||
|
version.NetbirdVersion(), prefix)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
|
||||||
|
func resolveURLsToIPs(urls []string) []net.IP {
|
||||||
|
var ips []net.IP
|
||||||
|
for _, rawurl := range urls {
|
||||||
|
u, err := url.Parse(rawurl)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to parse url %s: %v", rawurl, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ipAddrs, err := net.LookupIP(u.Hostname())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ips = append(ips, ipAddrs...)
|
||||||
|
}
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|||||||
@@ -28,13 +28,14 @@ const remotePeerKey2 = "remote1"
|
|||||||
|
|
||||||
func TestManagerUpdateRoutes(t *testing.T) {
|
func TestManagerUpdateRoutes(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
inputInitRoutes []*route.Route
|
inputInitRoutes []*route.Route
|
||||||
inputRoutes []*route.Route
|
inputRoutes []*route.Route
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
removeSrvRouter bool
|
removeSrvRouter bool
|
||||||
serverRoutesExpected int
|
serverRoutesExpected int
|
||||||
clientNetworkWatchersExpected int
|
clientNetworkWatchersExpected int
|
||||||
|
clientNetworkWatchersExpectedAllowed int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Should create 2 client networks",
|
name: "Should create 2 client networks",
|
||||||
@@ -200,8 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
clientNetworkWatchersExpected: 0,
|
clientNetworkWatchersExpected: 0,
|
||||||
|
clientNetworkWatchersExpectedAllowed: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Remove 1 Client Route",
|
name: "Remove 1 Client Route",
|
||||||
@@ -415,6 +417,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
statusRecorder := peer.NewRecorder("https://mgm")
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||||
|
|
||||||
|
_, _, err = routeManager.Init()
|
||||||
|
|
||||||
|
require.NoError(t, err, "should init route manager")
|
||||||
defer routeManager.Stop()
|
defer routeManager.Stop()
|
||||||
|
|
||||||
if testCase.removeSrvRouter {
|
if testCase.removeSrvRouter {
|
||||||
@@ -429,7 +435,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes")
|
require.NoError(t, err, "should update routes")
|
||||||
|
|
||||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||||
|
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
||||||
|
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
|
||||||
|
}
|
||||||
|
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||||
|
|
||||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||||
sr := routeManager.serverRouter.(*defaultServerRouter)
|
sr := routeManager.serverRouter.(*defaultServerRouter)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -16,6 +17,10 @@ type MockManager struct {
|
|||||||
StopFunc func()
|
StopFunc func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||||
func (m *MockManager) InitialRouteRange() []string {
|
func (m *MockManager) InitialRouteRange() []string {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
126
client/internal/routemanager/routemanager.go
Normal file
126
client/internal/routemanager/routemanager.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ref struct {
|
||||||
|
count int
|
||||||
|
nexthop netip.Addr
|
||||||
|
intf string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouteManager struct {
|
||||||
|
// refCountMap keeps track of the reference ref for prefixes
|
||||||
|
refCountMap map[netip.Prefix]ref
|
||||||
|
// prefixMap keeps track of the prefixes associated with a connection ID for removal
|
||||||
|
prefixMap map[nbnet.ConnectionID][]netip.Prefix
|
||||||
|
addRoute AddRouteFunc
|
||||||
|
removeRoute RemoveRouteFunc
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error)
|
||||||
|
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error
|
||||||
|
|
||||||
|
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
|
||||||
|
// TODO: read initial routing table into refCountMap
|
||||||
|
return &RouteManager{
|
||||||
|
refCountMap: map[netip.Prefix]ref{},
|
||||||
|
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
|
||||||
|
addRoute: addRoute,
|
||||||
|
removeRoute: removeRoute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||||
|
rm.mutex.Lock()
|
||||||
|
defer rm.mutex.Unlock()
|
||||||
|
|
||||||
|
ref := rm.refCountMap[prefix]
|
||||||
|
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
|
||||||
|
|
||||||
|
// Add route to the system, only if it's a new prefix
|
||||||
|
if ref.count == 0 {
|
||||||
|
log.Debugf("Adding route for prefix %s", prefix)
|
||||||
|
nexthop, intf, err := rm.addRoute(prefix)
|
||||||
|
if errors.Is(err, ErrRouteNotFound) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrRouteNotAllowed) {
|
||||||
|
log.Debugf("Adding route for prefix %s: %s", prefix, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
ref.nexthop = nexthop
|
||||||
|
ref.intf = intf
|
||||||
|
}
|
||||||
|
|
||||||
|
ref.count++
|
||||||
|
rm.refCountMap[prefix] = ref
|
||||||
|
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
|
||||||
|
rm.mutex.Lock()
|
||||||
|
defer rm.mutex.Unlock()
|
||||||
|
|
||||||
|
prefixes, ok := rm.prefixMap[connID]
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("No prefixes found for connection ID %s", connID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
ref := rm.refCountMap[prefix]
|
||||||
|
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
|
||||||
|
if ref.count == 1 {
|
||||||
|
log.Debugf("Removing route for prefix %s", prefix)
|
||||||
|
// TODO: don't fail if the route is not found
|
||||||
|
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(rm.refCountMap, prefix)
|
||||||
|
} else {
|
||||||
|
ref.count--
|
||||||
|
rm.refCountMap[prefix] = ref
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(rm.prefixMap, connID)
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush removes all references and routes from the system
|
||||||
|
func (rm *RouteManager) Flush() error {
|
||||||
|
rm.mutex.Lock()
|
||||||
|
defer rm.mutex.Unlock()
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for prefix := range rm.refCountMap {
|
||||||
|
log.Debugf("Removing route for prefix %s", prefix)
|
||||||
|
ref := rm.refCountMap[prefix]
|
||||||
|
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rm.refCountMap = map[netip.Prefix]ref{}
|
||||||
|
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
@@ -7,9 +7,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) {
|
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
|
||||||
return nil, fmt.Errorf("server route not supported on this os")
|
return nil, fmt.Errorf("server route not supported on this os")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,30 +4,34 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type defaultServerRouter struct {
|
type defaultServerRouter struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
routes map[string]*route.Route
|
routes map[string]*route.Route
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) {
|
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
||||||
return &defaultServerRouter{
|
return &defaultServerRouter{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
routes: make(map[string]*route.Route),
|
routes: make(map[string]*route.Route),
|
||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,7 +49,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
|||||||
oldRoute := m.routes[routeID]
|
oldRoute := m.routes[routeID]
|
||||||
err := m.removeFromServerNetwork(oldRoute)
|
err := m.removeFromServerNetwork(oldRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
|
||||||
oldRoute.ID, oldRoute.Network, err)
|
oldRoute.ID, oldRoute.Network, err)
|
||||||
}
|
}
|
||||||
delete(m.routes, routeID)
|
delete(m.routes, routeID)
|
||||||
@@ -59,7 +63,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
|||||||
|
|
||||||
err := m.addToServerNetwork(newRoute)
|
err := m.addToServerNetwork(newRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m.routes[id] = newRoute
|
m.routes[id] = newRoute
|
||||||
@@ -78,16 +82,28 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
|||||||
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("not removing from server network because context is done")
|
log.Infof("Not removing from server network because context is done")
|
||||||
return m.ctx.Err()
|
return m.ctx.Err()
|
||||||
default:
|
default:
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
|
||||||
|
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("parse prefix: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("remove routing rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
delete(m.routes, route.ID)
|
delete(m.routes, route.ID)
|
||||||
|
|
||||||
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
|
delete(state.Routes, route.Network.String())
|
||||||
|
m.statusRecorder.UpdateLocalPeerState(state)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -95,16 +111,31 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
|||||||
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("not adding to server network because context is done")
|
log.Infof("Not adding to server network because context is done")
|
||||||
return m.ctx.Err()
|
return m.ctx.Err()
|
||||||
default:
|
default:
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
|
||||||
|
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("parse prefix: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = m.firewall.InsertRoutingRules(routerPair)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert routing rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
m.routes[route.ID] = route
|
m.routes[route.ID] = route
|
||||||
|
|
||||||
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
|
if state.Routes == nil {
|
||||||
|
state.Routes = map[string]struct{}{}
|
||||||
|
}
|
||||||
|
state.Routes[route.Network.String()] = struct{}{}
|
||||||
|
m.statusRecorder.UpdateLocalPeerState(state)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -113,19 +144,33 @@ func (m *defaultServerRouter) cleanUp() {
|
|||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
for _, r := range m.routes {
|
for _, r := range m.routes {
|
||||||
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
|
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to remove clean up route: %s", r.ID)
|
log.Errorf("Failed to convert route to router pair: %v", err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
|
state.Routes = nil
|
||||||
|
m.statusRecorder.UpdateLocalPeerState(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
|
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
||||||
parsed := netip.MustParsePrefix(source).Masked()
|
parsed, err := netip.ParsePrefix(source)
|
||||||
|
if err != nil {
|
||||||
|
return firewall.RouterPair{}, err
|
||||||
|
}
|
||||||
return firewall.RouterPair{
|
return firewall.RouterPair{
|
||||||
ID: route.ID,
|
ID: route.ID,
|
||||||
Source: parsed.String(),
|
Source: parsed.String(),
|
||||||
Destination: route.Network.Masked().String(),
|
Destination: route.Network.Masked().String(),
|
||||||
Masquerade: route.Masquerade,
|
Masquerade: route.Masquerade,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
428
client/internal/routemanager/systemops.go
Normal file
428
client/internal/routemanager/systemops.go
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/libp2p/go-netroute"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||||
|
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||||
|
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||||
|
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||||
|
|
||||||
|
var ErrRouteNotFound = errors.New("route not found")
|
||||||
|
var ErrRouteNotAllowed = errors.New("route not allowed")
|
||||||
|
|
||||||
|
// TODO: fix: for default our wg address now appears as the default gw
|
||||||
|
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||||
|
addr := netip.IPv4Unspecified()
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
addr = netip.IPv6Unspecified()
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultGateway, _, err := getNextHop(addr)
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
return fmt.Errorf("get existing route gateway: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !prefix.Contains(defaultGateway) {
|
||||||
|
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
||||||
|
if defaultGateway.Is6() {
|
||||||
|
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := existsInRouteTable(gatewayPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitIntf string
|
||||||
|
gatewayHop, intf, err := getNextHop(defaultGateway)
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||||
|
}
|
||||||
|
if intf != nil {
|
||||||
|
exitIntf = intf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||||
|
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||||
|
r, err := netroute.New()
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||||
|
}
|
||||||
|
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get route for %s: %v", ip, err)
|
||||||
|
return netip.Addr{}, nil, ErrRouteNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||||
|
if gateway == nil {
|
||||||
|
if preferredSrc == nil {
|
||||||
|
return netip.Addr{}, nil, ErrRouteNotFound
|
||||||
|
}
|
||||||
|
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
||||||
|
|
||||||
|
addr, err := ipToAddr(preferredSrc, intf)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
|
||||||
|
}
|
||||||
|
return addr.Unmap(), intf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := ipToAddr(gateway, intf)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr, intf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
||||||
|
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||||
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
||||||
|
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
addr = addr.WithZone(strconv.Itoa(intf.Index))
|
||||||
|
} else {
|
||||||
|
addr = addr.WithZone(intf.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr.Unmap(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||||
|
routes, err := getRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get routes from table: %w", err)
|
||||||
|
}
|
||||||
|
for _, tableRoute := range routes {
|
||||||
|
if tableRoute == prefix {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||||
|
routes, err := getRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get routes from table: %w", err)
|
||||||
|
}
|
||||||
|
for _, tableRoute := range routes {
|
||||||
|
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||||
|
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||||
|
func addRouteToNonVPNIntf(
|
||||||
|
prefix netip.Prefix,
|
||||||
|
vpnIntf *iface.WGIface,
|
||||||
|
initialNextHop netip.Addr,
|
||||||
|
initialIntf *net.Interface,
|
||||||
|
) (netip.Addr, string, error) {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
switch {
|
||||||
|
case addr.IsLoopback(),
|
||||||
|
addr.IsLinkLocalUnicast(),
|
||||||
|
addr.IsLinkLocalMulticast(),
|
||||||
|
addr.IsInterfaceLocalMulticast(),
|
||||||
|
addr.IsUnspecified(),
|
||||||
|
addr.IsMulticast():
|
||||||
|
|
||||||
|
return netip.Addr{}, "", ErrRouteNotAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||||
|
nexthop, intf, err := getNextHop(addr)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
||||||
|
exitNextHop := nexthop
|
||||||
|
var exitIntf string
|
||||||
|
if intf != nil {
|
||||||
|
exitIntf = intf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||||
|
}
|
||||||
|
|
||||||
|
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||||
|
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
|
||||||
|
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||||
|
exitNextHop = initialNextHop
|
||||||
|
if initialIntf != nil {
|
||||||
|
exitIntf = initialIntf.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
||||||
|
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return exitNextHop, exitIntf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||||
|
// in two /1 prefixes to avoid replacing the existing default route
|
||||||
|
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if prefix == defaultv4 {
|
||||||
|
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove once IPv6 is supported on the interface
|
||||||
|
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
} else if prefix == defaultv6 {
|
||||||
|
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return addNonExistingRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||||
|
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
ok, err := existsInRouteTable(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("exists in route table: %w", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = isSubRange(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sub range: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
err := addRouteForCurrentDefaultGateway(prefix)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addToRouteTable(prefix, netip.Addr{}, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||||
|
// it will remove the split /1 prefixes
|
||||||
|
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if prefix == defaultv4 {
|
||||||
|
var result *multierror.Error
|
||||||
|
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove once IPv6 is supported on the interface
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
} else if prefix == defaultv6 {
|
||||||
|
var result *multierror.Error
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
||||||
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("parse IP address: %s", ip)
|
||||||
|
}
|
||||||
|
addr = addr.Unmap()
|
||||||
|
|
||||||
|
var prefixLength int
|
||||||
|
switch {
|
||||||
|
case addr.Is4():
|
||||||
|
prefixLength = 32
|
||||||
|
case addr.Is6():
|
||||||
|
prefixLength = 128
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||||
|
return &prefix, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||||
|
}
|
||||||
|
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
*routeManager = NewRouteManager(
|
||||||
|
func(prefix netip.Prefix) (netip.Addr, string, error) {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
nexthop, intf := initialNextHopV4, initialIntfV4
|
||||||
|
if addr.Is6() {
|
||||||
|
nexthop, intf = initialNextHopV6, initialIntfV6
|
||||||
|
}
|
||||||
|
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
||||||
|
},
|
||||||
|
removeFromRouteTable,
|
||||||
|
)
|
||||||
|
|
||||||
|
return setupHooks(*routeManager, initAddresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
||||||
|
if routeManager == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Remove hooks selectively
|
||||||
|
nbnet.RemoveDialerHooks()
|
||||||
|
nbnet.RemoveListenerHooks()
|
||||||
|
|
||||||
|
if err := routeManager.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush route manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||||
|
prefix, err := getPrefixFromIP(ip)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
||||||
|
return fmt.Errorf("adding route reference: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
afterHook := func(connID nbnet.ConnectionID) error {
|
||||||
|
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
||||||
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range initAddresses {
|
||||||
|
if err := beforeHook("init", ip); err != nil {
|
||||||
|
log.Errorf("Failed to add route reference: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for _, ip := range resolvedIPs {
|
||||||
|
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||||
|
}
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||||
|
return afterHook(connID)
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||||
|
return beforeHook(connID, ip.IP)
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||||
|
return afterHook(connID)
|
||||||
|
})
|
||||||
|
|
||||||
|
return beforeHook, afterHook, nil
|
||||||
|
}
|
||||||
@@ -1,13 +1,33 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRouting() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
func enableIPForwarding() error {
|
||||||
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addVPNRoute(netip.Prefix, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeVPNRoute(netip.Prefix, string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||||
// +build darwin dragonfly freebsd netbsd openbsd
|
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
@@ -9,6 +8,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(m.Addrs) < 3 {
|
||||||
|
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
addr, ok := toNetIPAddr(m.Addrs[0])
|
addr, ok := toNetIPAddr(m.Addrs[0])
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
mask, ok := toNetIPMASK(m.Addrs[2])
|
cidr := 32
|
||||||
if !ok {
|
if mask := m.Addrs[2]; mask != nil {
|
||||||
continue
|
cidr, ok = toCIDR(mask)
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cidr, _ := mask.Size()
|
|
||||||
|
|
||||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
routePrefix := netip.PrefixFrom(addr, cidr)
|
||||||
if routePrefix.IsValid() {
|
if routePrefix.IsValid() {
|
||||||
@@ -74,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
|
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case *route.Inet4Addr:
|
case *route.Inet4Addr:
|
||||||
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
return netip.AddrFrom4(t.IP), true
|
||||||
addr := netip.MustParseAddr(ip.String())
|
|
||||||
return addr, true
|
|
||||||
default:
|
default:
|
||||||
return netip.Addr{}, false
|
return netip.Addr{}, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toNetIPMASK(a route.Addr) (net.IPMask, bool) {
|
func toCIDR(a route.Addr) (int, bool) {
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case *route.Inet4Addr:
|
case *route.Inet4Addr:
|
||||||
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
||||||
return mask, true
|
cidr, _ := mask.Size()
|
||||||
|
return cidr, true
|
||||||
default:
|
default:
|
||||||
return nil, false
|
return 0, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
89
client/internal/routemanager/systemops_darwin.go
Normal file
89
client/internal/routemanager/systemops_darwin.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
var routeManager *RouteManager
|
||||||
|
|
||||||
|
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRouting() error {
|
||||||
|
return cleanupRoutingWithRouteManager(routeManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return routeCmd("add", prefix, nexthop, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return routeCmd("delete", prefix, nexthop, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
inet := "-inet"
|
||||||
|
network := prefix.String()
|
||||||
|
if prefix.IsSingleIP() {
|
||||||
|
network = prefix.Addr().String()
|
||||||
|
}
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
inet = "-inet6"
|
||||||
|
// Special case for IPv6 split default route, pointing to the wg interface fails
|
||||||
|
// TODO: Remove once we have IPv6 support on the interface
|
||||||
|
if prefix.Bits() == 1 {
|
||||||
|
intf = "lo0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{"-n", action, inet, network}
|
||||||
|
if nexthop.IsValid() {
|
||||||
|
args = append(args, nexthop.Unmap().String())
|
||||||
|
} else if intf != "" {
|
||||||
|
args = append(args, "-interface", intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := retryRouteCmd(args); err != nil {
|
||||||
|
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func retryRouteCmd(args []string) error {
|
||||||
|
operation := func() error {
|
||||||
|
out, err := exec.Command("route", args...).CombinedOutput()
|
||||||
|
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||||
|
// https://github.com/golang/go/issues/45736
|
||||||
|
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
|
||||||
|
return err
|
||||||
|
} else if err != nil {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expBackOff := backoff.NewExponentialBackOff()
|
||||||
|
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||||
|
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||||
|
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||||
|
|
||||||
|
err := backoff.Retry(operation, expBackOff)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route cmd retry failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
138
client/internal/routemanager/systemops_darwin_test.go
Normal file
138
client/internal/routemanager/systemops_darwin_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var expectedVPNint = "utun100"
|
||||||
|
var expectedExternalInt = "lo0"
|
||||||
|
var expectedInternalInt = "lo0"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
testCases = append(testCases, []testCase{
|
||||||
|
{
|
||||||
|
name: "To more specific route without custom dialer via vpn",
|
||||||
|
destination: "10.10.0.2:53",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentRoutes(t *testing.T) {
|
||||||
|
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||||
|
intf := "lo0"
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 1024; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ip netip.Addr) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
|
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||||
|
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}(baseIP)
|
||||||
|
baseIP = baseIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||||
|
|
||||||
|
for i := 0; i < 1024; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ip netip.Addr) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
|
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||||
|
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}(baseIP)
|
||||||
|
baseIP = baseIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||||
|
require.NoError(t, err, "Failed to create loopback alias")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||||
|
})
|
||||||
|
|
||||||
|
return "lo0"
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var originalNexthop net.IP
|
||||||
|
if dstCIDR == "0.0.0.0/0" {
|
||||||
|
var err error
|
||||||
|
originalNexthop, err = fetchOriginalGateway()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to fetch original gateway: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||||
|
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if originalNexthop != nil {
|
||||||
|
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||||
|
assert.NoError(t, err, "Failed to restore original route")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||||
|
require.NoError(t, err, "Failed to add route")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove route")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchOriginalGateway() (net.IP, error) {
|
||||||
|
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil, fmt.Errorf("gateway not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.ParseIP(matches[1]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||||
|
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
||||||
|
|
||||||
|
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||||
|
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
||||||
|
}
|
||||||
@@ -1,15 +1,33 @@
|
|||||||
//go:build ios
|
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRouting() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
func enableIPForwarding() error {
|
||||||
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addVPNRoute(netip.Prefix, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeVPNRoute(netip.Prefix, string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,149 +3,572 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
|
const (
|
||||||
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
|
// NetbirdVPNTableID is the ID of the custom routing table used by Netbird.
|
||||||
type routeInfoInMemory struct {
|
NetbirdVPNTableID = 0x1BD0
|
||||||
Family byte
|
// NetbirdVPNTableName is the name of the custom routing table used by Netbird.
|
||||||
DstLen byte
|
NetbirdVPNTableName = "netbird"
|
||||||
SrcLen byte
|
|
||||||
TOS byte
|
|
||||||
|
|
||||||
Table byte
|
// rtTablesPath is the path to the file containing the routing table names.
|
||||||
Protocol byte
|
rtTablesPath = "/etc/iproute2/rt_tables"
|
||||||
Scope byte
|
|
||||||
Type byte
|
|
||||||
|
|
||||||
Flags uint32
|
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||||
|
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||||
|
|
||||||
|
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||||
|
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||||
|
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||||
|
|
||||||
|
var routeManager = &RouteManager{}
|
||||||
|
|
||||||
|
// originalSysctl stores the original sysctl values before they are modified
|
||||||
|
var originalSysctl map[string]int
|
||||||
|
|
||||||
|
// determines whether to use the legacy routing setup
|
||||||
|
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||||
|
|
||||||
|
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
|
||||||
|
var sysctlFailed bool
|
||||||
|
|
||||||
|
type ruleParams struct {
|
||||||
|
priority int
|
||||||
|
fwmark int
|
||||||
|
tableID int
|
||||||
|
family int
|
||||||
|
invert bool
|
||||||
|
suppressPrefix int
|
||||||
|
description string
|
||||||
}
|
}
|
||||||
|
|
||||||
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
func getSetupRules() []ruleParams {
|
||||||
|
return []ruleParams{
|
||||||
|
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||||
|
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
|
||||||
|
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
|
||||||
|
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
// setupRouting establishes the routing configuration for the VPN, including essential rules
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
|
||||||
|
//
|
||||||
|
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
|
||||||
|
// potential routes received and configured for the VPN. This rule is skipped for the default route and routes
|
||||||
|
// that are not in the main table.
|
||||||
|
//
|
||||||
|
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||||
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
|
// enabling VPN connectivity.
|
||||||
|
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||||
|
if isLegacy {
|
||||||
|
log.Infof("Using legacy routing setup")
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = addRoutingTableName(); err != nil {
|
||||||
|
log.Errorf("Error adding routing table name: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
originalValues, err := setupSysctl(wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.Errorf("Error setting up sysctl: %v", err)
|
||||||
|
sysctlFailed = true
|
||||||
|
}
|
||||||
|
originalSysctl = originalValues
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||||
|
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
rules := getSetupRules()
|
||||||
|
for _, rule := range rules {
|
||||||
|
if err := addRule(rule); err != nil {
|
||||||
|
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
|
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||||
|
isLegacy = true
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
addrMask := "/32"
|
return nil, nil, nil
|
||||||
if prefix.Addr().Unmap().Is6() {
|
}
|
||||||
addrMask = "/128"
|
|
||||||
|
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||||
|
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||||
|
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||||
|
func cleanupRouting() error {
|
||||||
|
if isLegacy {
|
||||||
|
return cleanupRoutingWithRouteManager(routeManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, _, err := net.ParseCIDR(addr + addrMask)
|
var result *multierror.Error
|
||||||
if err != nil {
|
|
||||||
return err
|
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err))
|
||||||
|
}
|
||||||
|
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
route := &netlink.Route{
|
rules := getSetupRules()
|
||||||
Scope: netlink.SCOPE_UNIVERSE,
|
for _, rule := range rules {
|
||||||
Dst: ipNet,
|
if err := removeRule(rule); err != nil {
|
||||||
Gw: ip,
|
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = netlink.RouteAdd(route)
|
if err := cleanupSysctl(originalSysctl); err != nil {
|
||||||
if err != nil {
|
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
|
||||||
return err
|
}
|
||||||
|
originalSysctl = nil
|
||||||
|
sysctlFailed = false
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if isLegacy {
|
||||||
|
return genericAddVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
|
||||||
|
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||||
|
|
||||||
|
// TODO remove this once we have ipv6 support
|
||||||
|
if prefix == defaultv4 {
|
||||||
|
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||||
|
return fmt.Errorf("add blackhole: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||||
|
return fmt.Errorf("add route: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeFromRouteTable(prefix netip.Prefix, addr string) error {
|
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
if isLegacy {
|
||||||
if err != nil {
|
return genericRemoveVPNRoute(prefix, intf)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
addrMask := "/32"
|
// TODO remove this once we have ipv6 support
|
||||||
if prefix.Addr().Unmap().Is6() {
|
if prefix == defaultv4 {
|
||||||
addrMask = "/128"
|
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||||
|
return fmt.Errorf("remove unreachable route: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||||
ip, _, err := net.ParseCIDR(addr + addrMask)
|
return fmt.Errorf("remove route: %w", err)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
route := &netlink.Route{
|
|
||||||
Scope: netlink.SCOPE_UNIVERSE,
|
|
||||||
Dst: ipNet,
|
|
||||||
Gw: ip,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = netlink.RouteDel(route)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
|
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("get v4 routes: %w", err)
|
||||||
}
|
}
|
||||||
msgs, err := syscall.ParseNetlinkMessage(tab)
|
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("get v6 routes: %w", err)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
return append(v4Routes, v6Routes...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRoutes fetches routes from a specific routing table identified by tableID.
|
||||||
|
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||||
var prefixList []netip.Prefix
|
var prefixList []netip.Prefix
|
||||||
loop:
|
|
||||||
for _, m := range msgs {
|
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
||||||
switch m.Header.Type {
|
if err != nil {
|
||||||
case syscall.NLMSG_DONE:
|
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
|
||||||
break loop
|
}
|
||||||
case syscall.RTM_NEWROUTE:
|
|
||||||
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
|
for _, route := range routes {
|
||||||
msg := m
|
if route.Dst != nil {
|
||||||
attrs, err := syscall.ParseNetlinkRouteAttr(&msg)
|
addr, ok := netip.AddrFromSlice(route.Dst.IP)
|
||||||
if err != nil {
|
if !ok {
|
||||||
return nil, err
|
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
|
||||||
}
|
|
||||||
if rt.Family != syscall.AF_INET {
|
|
||||||
continue loop
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, attr := range attrs {
|
ones, _ := route.Dst.Mask.Size()
|
||||||
if attr.Attr.Type == syscall.RTA_DST {
|
|
||||||
addr, ok := netip.AddrFromSlice(attr.Value)
|
prefix := netip.PrefixFrom(addr, ones)
|
||||||
if !ok {
|
if prefix.IsValid() {
|
||||||
continue
|
prefixList = append(prefixList, prefix)
|
||||||
}
|
|
||||||
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
|
|
||||||
cidr, _ := mask.Size()
|
|
||||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
|
||||||
if routePrefix.IsValid() && routePrefix.Addr().Is4() {
|
|
||||||
prefixList = append(prefixList, routePrefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
// addRoute adds a route to a specific routing table identified by tableID.
|
||||||
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
||||||
if err != nil {
|
route := &netlink.Route{
|
||||||
return err
|
Scope: netlink.SCOPE_UNIVERSE,
|
||||||
|
Table: tableID,
|
||||||
|
Family: getAddressFamily(prefix),
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if it is already enabled
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
// see more: https://github.com/netbirdio/netbird/issues/872
|
if err != nil {
|
||||||
if len(bytes) > 0 && bytes[0] == 49 {
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
route.Dst = ipNet
|
||||||
|
|
||||||
|
if err := addNextHop(addr, intf, route); err != nil {
|
||||||
|
return fmt.Errorf("add gateway and device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("netlink add route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
||||||
|
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
||||||
|
// tableID specifies the routing table to which the unreachable route will be added.
|
||||||
|
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
route := &netlink.Route{
|
||||||
|
Type: syscall.RTN_UNREACHABLE,
|
||||||
|
Table: tableID,
|
||||||
|
Family: getAddressFamily(prefix),
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("netlink add unreachable route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
route := &netlink.Route{
|
||||||
|
Type: syscall.RTN_UNREACHABLE,
|
||||||
|
Table: tableID,
|
||||||
|
Family: getAddressFamily(prefix),
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||||
|
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
||||||
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
route := &netlink.Route{
|
||||||
|
Scope: netlink.SCOPE_UNIVERSE,
|
||||||
|
Table: tableID,
|
||||||
|
Family: getAddressFamily(prefix),
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := addNextHop(addr, intf, route); err != nil {
|
||||||
|
return fmt.Errorf("add gateway and device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("netlink remove route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flushRoutes(tableID, family int) error {
|
||||||
|
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list routes from table %d: %w", tableID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for i := range routes {
|
||||||
|
route := routes[i]
|
||||||
|
// unreachable default routes don't come back with Dst set
|
||||||
|
if route.Gw == nil && route.Src == nil && route.Dst == nil {
|
||||||
|
if family == netlink.FAMILY_V4 {
|
||||||
|
routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
|
||||||
|
} else {
|
||||||
|
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
func enableIPForwarding() error {
|
||||||
|
_, err := setSysctl(ipv4ForwardingPath, 1, false)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
||||||
|
// and verifies if existing names start with "netbird_".
|
||||||
|
func entryExists(file *os.File, id int) (bool, error) {
|
||||||
|
if _, err := file.Seek(0, 0); err != nil {
|
||||||
|
return false, fmt.Errorf("seek rt_tables: %w", err)
|
||||||
|
}
|
||||||
|
scanner := bufio.NewScanner(file)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
var existingID int
|
||||||
|
var existingName string
|
||||||
|
if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil {
|
||||||
|
if existingID == id {
|
||||||
|
if existingName != NetbirdVPNTableName {
|
||||||
|
return true, ErrTableIDExists
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return false, fmt.Errorf("scan rt_tables: %w", err)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRoutingTableName adds human-readable names for custom routing tables.
|
||||||
|
func addRoutingTableName() error {
|
||||||
|
file, err := os.Open(rtTablesPath)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("open rt_tables: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing rt_tables: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
exists, err := entryExists(file, NetbirdVPNTableID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec
|
// Reopen the file in append mode to add new entries
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing rt_tables before appending: %v", err)
|
||||||
|
}
|
||||||
|
file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open rt_tables for appending: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil {
|
||||||
|
return fmt.Errorf("append entry to rt_tables: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRule adds a routing rule to a specific routing table identified by tableID.
|
||||||
|
func addRule(params ruleParams) error {
|
||||||
|
rule := netlink.NewRule()
|
||||||
|
rule.Table = params.tableID
|
||||||
|
rule.Mark = params.fwmark
|
||||||
|
rule.Family = params.family
|
||||||
|
rule.Priority = params.priority
|
||||||
|
rule.Invert = params.invert
|
||||||
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
|
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("add routing rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRule removes a routing rule from a specific routing table identified by tableID.
|
||||||
|
func removeRule(params ruleParams) error {
|
||||||
|
rule := netlink.NewRule()
|
||||||
|
rule.Table = params.tableID
|
||||||
|
rule.Mark = params.fwmark
|
||||||
|
rule.Family = params.family
|
||||||
|
rule.Invert = params.invert
|
||||||
|
rule.Priority = params.priority
|
||||||
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
|
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return fmt.Errorf("remove routing rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNextHop adds the gateway and device to the route.
|
||||||
|
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
|
||||||
|
if addr.IsValid() {
|
||||||
|
route.Gw = addr.AsSlice()
|
||||||
|
if intf == "" {
|
||||||
|
intf = addr.Zone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if intf != "" {
|
||||||
|
link, err := netlink.LinkByName(intf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("set interface %s: %w", intf, err)
|
||||||
|
}
|
||||||
|
route.LinkIndex = link.Attrs().Index
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAddressFamily(prefix netip.Prefix) int {
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
return netlink.FAMILY_V4
|
||||||
|
}
|
||||||
|
return netlink.FAMILY_V6
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupSysctl configures sysctl settings for RP filtering and source validation.
|
||||||
|
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
|
||||||
|
keys := map[string]int{}
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
|
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[srcValidMarkPath] = oldVal
|
||||||
|
}
|
||||||
|
|
||||||
|
oldVal, err = setSysctl(rpFilterPath, 2, true)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[rpFilterPath] = oldVal
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, intf := range interfaces {
|
||||||
|
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||||
|
oldVal, err := setSysctl(i, 2, true)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
} else {
|
||||||
|
keys[i] = oldVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||||
|
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||||
|
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||||
|
currentValue, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||||
|
if err != nil && len(currentValue) > 0 {
|
||||||
|
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||||
|
return currentV, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gosec
|
||||||
|
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||||
|
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||||
|
|
||||||
|
return currentV, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupSysctl(originalSettings map[string]int) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
|
for key, value := range originalSettings {
|
||||||
|
_, err := setSysctl(key, value, false)
|
||||||
|
if err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|||||||
207
client/internal/routemanager/systemops_linux_test.go
Normal file
207
client/internal/routemanager/systemops_linux_test.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
var expectedVPNint = "wgtest0"
|
||||||
|
var expectedLoopbackInt = "lo"
|
||||||
|
var expectedExternalInt = "dummyext0"
|
||||||
|
var expectedInternalInt = "dummyint0"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
testCases = append(testCases, []testCase{
|
||||||
|
{
|
||||||
|
name: "To more specific route without custom dialer via physical interface",
|
||||||
|
destination: "10.10.0.2:53",
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To more specific route (local) without custom dialer via physical interface",
|
||||||
|
destination: "127.0.10.1:53",
|
||||||
|
expectedInterface: expectedLoopbackInt,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEntryExists(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||||
|
|
||||||
|
content := []string{
|
||||||
|
"1000 reserved",
|
||||||
|
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||||
|
"9999 other_table",
|
||||||
|
}
|
||||||
|
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||||
|
|
||||||
|
file, err := os.Open(tempFilePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
assert.NoError(t, file.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
id int
|
||||||
|
shouldExist bool
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ExistsWithNetbirdPrefix",
|
||||||
|
id: 7120,
|
||||||
|
shouldExist: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ExistsWithDifferentName",
|
||||||
|
id: 1000,
|
||||||
|
shouldExist: true,
|
||||||
|
err: ErrTableIDExists,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DoesNotExist",
|
||||||
|
id: 1234,
|
||||||
|
shouldExist: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
exists, err := entryExists(file, tc.id)
|
||||||
|
if tc.err != nil {
|
||||||
|
assert.ErrorIs(t, err, tc.err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.shouldExist, exists)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
|
||||||
|
err := netlink.LinkDel(dummy)
|
||||||
|
if err != nil && !errors.Is(err, syscall.EINVAL) {
|
||||||
|
t.Logf("Failed to delete dummy interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = netlink.LinkAdd(dummy)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = netlink.LinkSetUp(dummy)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if ipAddressCIDR != "" {
|
||||||
|
addr, err := netlink.ParseAddr(ipAddressCIDR)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = netlink.AddrAdd(dummy, addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := netlink.LinkDel(dummy)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
return dummy.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Handle existing routes with metric 0
|
||||||
|
var originalNexthop net.IP
|
||||||
|
var originalLinkIndex int
|
||||||
|
if dstIPNet.String() == "0.0.0.0/0" {
|
||||||
|
var err error
|
||||||
|
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
|
||||||
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
|
t.Logf("Failed to fetch original gateway: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if originalNexthop != nil {
|
||||||
|
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
|
||||||
|
switch {
|
||||||
|
case err != nil && !errors.Is(err, syscall.ESRCH):
|
||||||
|
t.Logf("Failed to delete route: %v", err)
|
||||||
|
case err == nil:
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0})
|
||||||
|
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||||
|
t.Fatalf("Failed to add route: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
t.Logf("Failed to delete route: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
link, err := netlink.LinkByName(intf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
linkIndex := link.Attrs().Index
|
||||||
|
|
||||||
|
route := &netlink.Route{
|
||||||
|
Dst: dstIPNet,
|
||||||
|
Gw: gw,
|
||||||
|
LinkIndex: linkIndex,
|
||||||
|
}
|
||||||
|
err = netlink.RouteDel(route)
|
||||||
|
if err != nil && !errors.Is(err, syscall.ESRCH) {
|
||||||
|
t.Logf("Failed to delete route: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = netlink.RouteAdd(route)
|
||||||
|
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||||
|
t.Fatalf("Failed to add route: %v", err)
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchOriginalGateway(family int) (net.IP, int, error) {
|
||||||
|
routes, err := netlink.RouteList(nil, family)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
if route.Dst == nil && route.Priority == 0 {
|
||||||
|
return route.Gw, route.LinkIndex, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, 0, ErrRouteNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
|
||||||
|
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
||||||
|
|
||||||
|
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
|
||||||
|
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
||||||
|
}
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
//go:build !android && !ios
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/libp2p/go-netroute"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
var errRouteNotFound = fmt.Errorf("route not found")
|
|
||||||
|
|
||||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
|
||||||
ok, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
log.Warnf("skipping adding a new route for network %s because it already exists", prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
err := addRouteForCurrentDefaultGateway(prefix)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return addToRouteTable(prefix, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|
||||||
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
|
||||||
if err != nil && err != errRouteNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := netip.MustParseAddr(defaultGateway.String())
|
|
||||||
|
|
||||||
if !prefix.Contains(addr) {
|
|
||||||
log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayPrefix := netip.PrefixFrom(addr, 32)
|
|
||||||
|
|
||||||
ok, err := existsInRouteTable(gatewayPrefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
|
|
||||||
if err != nil && err != errRouteNotFound {
|
|
||||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
|
||||||
}
|
|
||||||
log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
|
||||||
return addToRouteTable(gatewayPrefix, gatewayHop.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := getRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute == prefix {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := getRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
|
||||||
return removeFromRouteTable(prefix, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
|
|
||||||
r, err := netroute.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("getting routes returned an error: %v", err)
|
|
||||||
return nil, errRouteNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
if gateway == nil {
|
|
||||||
return preferredSrc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return gateway, nil
|
|
||||||
}
|
|
||||||
@@ -1,41 +1,23 @@
|
|||||||
//go:build !linux
|
//go:build !linux && !ios
|
||||||
// +build !linux
|
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
|
||||||
cmd := exec.Command("route", "add", prefix.String(), addr)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debugf(string(out))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeFromRouteTable(prefix netip.Prefix, addr string) error {
|
|
||||||
args := []string{"delete", prefix.String()}
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
args = append(args, addr)
|
|
||||||
}
|
|
||||||
cmd := exec.Command("route", args...)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debugf(string(out))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func enableIPForwarding() error {
|
||||||
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
return genericAddVPNRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
return genericRemoveVPNRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,24 +1,32 @@
|
|||||||
//go:build !android
|
//go:build !android && !ios
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type dialer interface {
|
||||||
|
Dial(network, address string) (net.Conn, error)
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAddRemoveRoutes(t *testing.T) {
|
func TestAddRemoveRoutes(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -53,27 +61,30 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
|
|
||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
_, _, err = setupRouting(nil, wgInterface)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, cleanupRouting())
|
||||||
|
})
|
||||||
|
|
||||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String())
|
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||||
|
|
||||||
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
|
||||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
|
||||||
if testCase.shouldRouteToWireguard {
|
if testCase.shouldRouteToWireguard {
|
||||||
require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
||||||
} else {
|
} else {
|
||||||
require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface")
|
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
|
||||||
}
|
}
|
||||||
exists, err := existsInRouteTable(testCase.prefix)
|
exists, err := existsInRouteTable(testCase.prefix)
|
||||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||||
if exists && testCase.shouldRouteToWireguard {
|
if exists && testCase.shouldRouteToWireguard {
|
||||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String())
|
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
|
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||||
|
|
||||||
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
|
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
||||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
require.NoError(t, err, "getNextHop should not return err")
|
||||||
|
|
||||||
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if testCase.shouldBeRemoved {
|
if testCase.shouldBeRemoved {
|
||||||
@@ -86,12 +97,12 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetExistingRIBRouteGateway(t *testing.T) {
|
func TestGetNextHop(t *testing.T) {
|
||||||
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
}
|
}
|
||||||
if gateway == nil {
|
if !gateway.IsValid() {
|
||||||
t.Fatal("should return a gateway")
|
t.Fatal("should return a gateway")
|
||||||
}
|
}
|
||||||
addresses, err := net.InterfaceAddrs()
|
addresses, err := net.InterfaceAddrs()
|
||||||
@@ -113,11 +124,11 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localIP, err := getExistingRIBRouteGateway(testingPrefix)
|
localIP, _, err := getNextHop(testingPrefix.Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error: ", err)
|
t.Fatal("shouldn't return error: ", err)
|
||||||
}
|
}
|
||||||
if localIP == nil {
|
if !localIP.IsValid() {
|
||||||
t.Fatal("should return a gateway for local network")
|
t.Fatal("should return a gateway for local network")
|
||||||
}
|
}
|
||||||
if localIP.String() == gateway.String() {
|
if localIP.String() == gateway.String() {
|
||||||
@@ -128,8 +139,8 @@ func TestGetExistingRIBRouteGateway(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||||
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
t.Log("defaultGateway: ", defaultGateway)
|
t.Log("defaultGateway: ", defaultGateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
@@ -189,16 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
|
||||||
MockAddr := wgInterface.Address().IP.String()
|
|
||||||
|
|
||||||
// Prepare the environment
|
// Prepare the environment
|
||||||
if testCase.preExistingPrefix.IsValid() {
|
if testCase.preExistingPrefix.IsValid() {
|
||||||
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr)
|
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the route
|
// Add the route
|
||||||
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr)
|
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err when adding route")
|
require.NoError(t, err, "should not return err when adding route")
|
||||||
|
|
||||||
if testCase.shouldAddRoute {
|
if testCase.shouldAddRoute {
|
||||||
@@ -208,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
require.True(t, ok, "route should exist")
|
require.True(t, ok, "route should exist")
|
||||||
|
|
||||||
// remove route again if added
|
// remove route again if added
|
||||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr)
|
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,6 +226,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
ok, err := existsInRouteTable(testCase.prefix)
|
ok, err := existsInRouteTable(testCase.prefix)
|
||||||
t.Log("Buffer string: ", buf.String())
|
t.Log("Buffer string: ", buf.String())
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), "because it already exists") {
|
if !strings.Contains(buf.String(), "because it already exists") {
|
||||||
require.False(t, ok, "route should not exist")
|
require.False(t, ok, "route should not exist")
|
||||||
}
|
}
|
||||||
@@ -224,31 +234,6 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExistsInRouteTable(t *testing.T) {
|
|
||||||
addresses, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var addressPrefixes []netip.Prefix
|
|
||||||
for _, address := range addresses {
|
|
||||||
p := netip.MustParsePrefix(address.String())
|
|
||||||
if p.Addr().Is4() {
|
|
||||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range addressPrefixes {
|
|
||||||
exists, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
t.Fatalf("address %s should exist in route table", prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsSubRange(t *testing.T) {
|
func TestIsSubRange(t *testing.T) {
|
||||||
addresses, err := net.InterfaceAddrs()
|
addresses, err := net.InterfaceAddrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -286,3 +271,132 @@ func TestIsSubRange(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExistsInRouteTable(t *testing.T) {
|
||||||
|
addresses, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var addressPrefixes []netip.Prefix
|
||||||
|
for _, address := range addresses {
|
||||||
|
p := netip.MustParsePrefix(address.String())
|
||||||
|
if p.Addr().Is6() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||||
|
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||||
|
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range addressPrefixes {
|
||||||
|
exists, err := existsInRouteTable(prefix)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("address %s should exist in route table", prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newNet, err := stdnet.NewNet()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||||
|
require.NoError(t, err, "should create testing WireGuard interface")
|
||||||
|
|
||||||
|
err = wgInterface.Create()
|
||||||
|
require.NoError(t, err, "should create testing WireGuard interface")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
wgInterface.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return wgInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestEnv(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
setupDummyInterfacesAndRoutes(t)
|
||||||
|
|
||||||
|
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, wgIface.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
_, _, err := setupRouting(nil, wgIface)
|
||||||
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, cleanupRouting())
|
||||||
|
})
|
||||||
|
|
||||||
|
// default route exists in main table and vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 10.0.0.0/8 route exists in main table and vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 10.10.0.0/24 more specific route exists in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 127.0.10.0/24 more specific route exists in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// unique route in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||||
|
t.Helper()
|
||||||
|
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixGateway, _, err := getNextHop(prefix.Addr())
|
||||||
|
require.NoError(t, err, "getNextHop should not return err")
|
||||||
|
if invert {
|
||||||
|
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
234
client/internal/routemanager/systemops_unix_test.go
Normal file
234
client/internal/routemanager/systemops_unix_test.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gopacket/gopacket"
|
||||||
|
"github.com/gopacket/gopacket/layers"
|
||||||
|
"github.com/gopacket/gopacket/pcap"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketExpectation struct {
|
||||||
|
SrcIP net.IP
|
||||||
|
DstIP net.IP
|
||||||
|
SrcPort int
|
||||||
|
DstPort int
|
||||||
|
UDP bool
|
||||||
|
TCP bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
destination string
|
||||||
|
expectedInterface string
|
||||||
|
dialer dialer
|
||||||
|
expectedPacket PacketExpectation
|
||||||
|
}
|
||||||
|
|
||||||
|
var testCases = []testCase{
|
||||||
|
{
|
||||||
|
name: "To external host without custom dialer via vpn",
|
||||||
|
destination: "192.0.2.1:53",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To external host with custom dialer via physical interface",
|
||||||
|
destination: "192.0.2.1:53",
|
||||||
|
expectedInterface: expectedExternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route with custom dialer via physical interface",
|
||||||
|
destination: "10.0.0.2:53",
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||||
|
destination: "10.0.0.2:53",
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To unique vpn route with custom dialer via physical interface",
|
||||||
|
destination: "172.16.0.2:53",
|
||||||
|
expectedInterface: expectedExternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To unique vpn route without custom dialer via vpn",
|
||||||
|
destination: "172.16.0.2:53",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouting(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
setupTestEnv(t)
|
||||||
|
|
||||||
|
filter := createBPFFilter(tc.destination)
|
||||||
|
handle := startPacketCapture(t, tc.expectedInterface, filter)
|
||||||
|
|
||||||
|
sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer)
|
||||||
|
|
||||||
|
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
||||||
|
packet, err := packetSource.NextPacket()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
verifyPacket(t, packet, tc.expectedPacket)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||||
|
return PacketExpectation{
|
||||||
|
SrcIP: net.ParseIP(srcIP),
|
||||||
|
DstIP: net.ParseIP(dstIP),
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
UDP: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
inactive, err := pcap.NewInactiveHandle(intf)
|
||||||
|
require.NoError(t, err, "Failed to create inactive pcap handle")
|
||||||
|
defer inactive.CleanUp()
|
||||||
|
|
||||||
|
err = inactive.SetSnapLen(1600)
|
||||||
|
require.NoError(t, err, "Failed to set snap length on inactive handle")
|
||||||
|
|
||||||
|
err = inactive.SetTimeout(time.Second * 10)
|
||||||
|
require.NoError(t, err, "Failed to set timeout on inactive handle")
|
||||||
|
|
||||||
|
err = inactive.SetImmediateMode(true)
|
||||||
|
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
|
||||||
|
|
||||||
|
handle, err := inactive.Activate()
|
||||||
|
require.NoError(t, err, "Failed to activate pcap handle")
|
||||||
|
t.Cleanup(handle.Close)
|
||||||
|
|
||||||
|
err = handle.SetBPFFilter(filter)
|
||||||
|
require.NoError(t, err, "Failed to set BPF filter")
|
||||||
|
|
||||||
|
return handle
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if dialer == nil {
|
||||||
|
dialer = &net.Dialer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sourcePort != 0 {
|
||||||
|
localUDPAddr := &net.UDPAddr{
|
||||||
|
IP: net.IPv4zero,
|
||||||
|
Port: sourcePort,
|
||||||
|
}
|
||||||
|
switch dialer := dialer.(type) {
|
||||||
|
case *nbnet.Dialer:
|
||||||
|
dialer.LocalAddr = localUDPAddr
|
||||||
|
case *net.Dialer:
|
||||||
|
dialer.LocalAddr = localUDPAddr
|
||||||
|
default:
|
||||||
|
t.Fatal("Unsupported dialer type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.Id = dns.Id()
|
||||||
|
msg.RecursionDesired = true
|
||||||
|
msg.Question = []dns.Question{
|
||||||
|
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialer.Dial("udp", destination)
|
||||||
|
require.NoError(t, err, "Failed to dial UDP")
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
data, err := msg.Pack()
|
||||||
|
require.NoError(t, err, "Failed to pack DNS message")
|
||||||
|
|
||||||
|
_, err = conn.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "required key not available") {
|
||||||
|
t.Logf("Ignoring WireGuard key error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("Failed to send DNS query: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createBPFFilter(destination string) string {
|
||||||
|
host, port, err := net.SplitHostPort(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
|
||||||
|
}
|
||||||
|
return "udp"
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
||||||
|
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
|
||||||
|
|
||||||
|
ip, ok := ipLayer.(*layers.IPv4)
|
||||||
|
require.True(t, ok, "Failed to cast to IPv4 layer")
|
||||||
|
|
||||||
|
// Convert both source and destination IP addresses to 16-byte representation
|
||||||
|
expectedSrcIP := exp.SrcIP.To16()
|
||||||
|
actualSrcIP := ip.SrcIP.To16()
|
||||||
|
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
|
||||||
|
|
||||||
|
expectedDstIP := exp.DstIP.To16()
|
||||||
|
actualDstIP := ip.DstIP.To16()
|
||||||
|
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
|
||||||
|
|
||||||
|
if exp.UDP {
|
||||||
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||||
|
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
|
||||||
|
|
||||||
|
udp, ok := udpLayer.(*layers.UDP)
|
||||||
|
require.True(t, ok, "Failed to cast to UDP layer")
|
||||||
|
|
||||||
|
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
|
||||||
|
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if exp.TCP {
|
||||||
|
tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
||||||
|
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
|
||||||
|
|
||||||
|
tcp, ok := tcpLayer.(*layers.TCP)
|
||||||
|
require.True(t, ok, "Failed to cast to TCP layer")
|
||||||
|
|
||||||
|
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
|
||||||
|
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,13 +1,19 @@
|
|||||||
//go:build windows
|
//go:build windows
|
||||||
// +build windows
|
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/yusufpapurcu/wmi"
|
"github.com/yusufpapurcu/wmi"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Win32_IP4RouteTable struct {
|
type Win32_IP4RouteTable struct {
|
||||||
@@ -15,23 +21,35 @@ type Win32_IP4RouteTable struct {
|
|||||||
Mask string
|
Mask string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var routeManager *RouteManager
|
||||||
|
|
||||||
|
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRouting() error {
|
||||||
|
return cleanupRoutingWithRouteManager(routeManager)
|
||||||
|
}
|
||||||
|
|
||||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
var routes []Win32_IP4RouteTable
|
var routes []Win32_IP4RouteTable
|
||||||
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
||||||
|
|
||||||
err := wmi.Query(query, &routes)
|
err := wmi.Query(query, &routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("get routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var prefixList []netip.Prefix
|
var prefixList []netip.Prefix
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
addr, err := netip.ParseAddr(route.Destination)
|
addr, err := netip.ParseAddr(route.Destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Warnf("Unable to parse route destination %s: %v", route.Destination, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
maskSlice := net.ParseIP(route.Mask).To4()
|
maskSlice := net.ParseIP(route.Mask).To4()
|
||||||
if maskSlice == nil {
|
if maskSlice == nil {
|
||||||
|
log.Warnf("Unable to parse route mask %s", route.Mask)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
|
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
|
||||||
@@ -44,3 +62,86 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
}
|
}
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
|
||||||
|
destinationPrefix := prefix.String()
|
||||||
|
psCmd := "New-NetRoute"
|
||||||
|
|
||||||
|
addressFamily := "IPv4"
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
addressFamily = "IPv6"
|
||||||
|
}
|
||||||
|
|
||||||
|
script := fmt.Sprintf(
|
||||||
|
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
|
||||||
|
psCmd, addressFamily, destinationPrefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
if intfIdx != "" {
|
||||||
|
script = fmt.Sprintf(
|
||||||
|
`%s -InterfaceIndex %s`, script, intfIdx,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
script = fmt.Sprintf(
|
||||||
|
`%s -InterfaceAlias "%s"`, script, intf,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nexthop.IsValid() {
|
||||||
|
script = fmt.Sprintf(
|
||||||
|
`%s -NextHop "%s"`, script, nexthop,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||||
|
log.Tracef("PowerShell %s: %s", script, string(out))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("PowerShell add route: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
||||||
|
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
|
||||||
|
|
||||||
|
out, err := exec.Command("route", args...).CombinedOutput()
|
||||||
|
|
||||||
|
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route add: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
var intfIdx string
|
||||||
|
if nexthop.Zone() != "" {
|
||||||
|
intfIdx = nexthop.Zone()
|
||||||
|
nexthop.WithZone("")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Powershell doesn't support adding routes without an interface but allows to add interface by name
|
||||||
|
if intf != "" || intfIdx != "" {
|
||||||
|
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
|
||||||
|
}
|
||||||
|
return addRouteCmd(prefix, nexthop, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
||||||
|
args := []string{"delete", prefix.String()}
|
||||||
|
if nexthop.IsValid() {
|
||||||
|
nexthop.WithZone("")
|
||||||
|
args = append(args, nexthop.Unmap().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command("route", args...).CombinedOutput()
|
||||||
|
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("remove route: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
289
client/internal/routemanager/systemops_windows_test.go
Normal file
289
client/internal/routemanager/systemops_windows_test.go
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
var expectedExtInt = "Ethernet1"
|
||||||
|
|
||||||
|
type RouteInfo struct {
|
||||||
|
NextHop string `json:"nexthop"`
|
||||||
|
InterfaceAlias string `json:"interfacealias"`
|
||||||
|
RouteMetric int `json:"routemetric"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FindNetRouteOutput struct {
|
||||||
|
IPAddress string `json:"IPAddress"`
|
||||||
|
InterfaceIndex int `json:"InterfaceIndex"`
|
||||||
|
InterfaceAlias string `json:"InterfaceAlias"`
|
||||||
|
AddressFamily int `json:"AddressFamily"`
|
||||||
|
NextHop string `json:"NextHop"`
|
||||||
|
DestinationPrefix string `json:"DestinationPrefix"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
destination string
|
||||||
|
expectedSourceIP string
|
||||||
|
expectedDestPrefix string
|
||||||
|
expectedNextHop string
|
||||||
|
expectedInterface string
|
||||||
|
dialer dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
var expectedVPNint = "wgtest0"
|
||||||
|
|
||||||
|
var testCases = []testCase{
|
||||||
|
{
|
||||||
|
name: "To external host without custom dialer via vpn",
|
||||||
|
destination: "192.0.2.1:53",
|
||||||
|
expectedSourceIP: "100.64.0.1",
|
||||||
|
expectedDestPrefix: "128.0.0.0/1",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "wgtest0",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To external host with custom dialer via physical interface",
|
||||||
|
destination: "192.0.2.1:53",
|
||||||
|
expectedDestPrefix: "192.0.2.1/32",
|
||||||
|
expectedInterface: expectedExtInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route with custom dialer via physical interface",
|
||||||
|
destination: "10.0.0.2:53",
|
||||||
|
expectedDestPrefix: "10.0.0.2/32",
|
||||||
|
expectedInterface: expectedExtInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||||
|
destination: "10.0.0.2:53",
|
||||||
|
expectedSourceIP: "10.0.0.1",
|
||||||
|
expectedDestPrefix: "10.0.0.0/8",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To unique vpn route with custom dialer via physical interface",
|
||||||
|
destination: "172.16.0.2:53",
|
||||||
|
expectedDestPrefix: "172.16.0.2/32",
|
||||||
|
expectedInterface: expectedExtInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To unique vpn route without custom dialer via vpn",
|
||||||
|
destination: "172.16.0.2:53",
|
||||||
|
expectedSourceIP: "100.64.0.1",
|
||||||
|
expectedDestPrefix: "172.16.0.0/12",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "wgtest0",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To more specific route without custom dialer via vpn interface",
|
||||||
|
destination: "10.10.0.2:53",
|
||||||
|
expectedSourceIP: "100.64.0.1",
|
||||||
|
expectedDestPrefix: "10.10.0.0/24",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "wgtest0",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To more specific route (local) without custom dialer via physical interface",
|
||||||
|
destination: "127.0.10.2:53",
|
||||||
|
expectedSourceIP: "10.0.0.1",
|
||||||
|
expectedDestPrefix: "127.0.0.0/8",
|
||||||
|
expectedNextHop: "0.0.0.0",
|
||||||
|
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouting(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
setupTestEnv(t)
|
||||||
|
|
||||||
|
route, err := fetchOriginalGateway()
|
||||||
|
require.NoError(t, err, "Failed to fetch original gateway")
|
||||||
|
ip, err := fetchInterfaceIP(route.InterfaceAlias)
|
||||||
|
require.NoError(t, err, "Failed to fetch interface IP")
|
||||||
|
|
||||||
|
output := testRoute(t, tc.destination, tc.dialer)
|
||||||
|
if tc.expectedInterface == expectedExtInt {
|
||||||
|
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
|
||||||
|
} else {
|
||||||
|
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchInterfaceIP fetches the IPv4 address of the specified interface.
|
||||||
|
func fetchInterfaceIP(interfaceAlias string) (string, error) {
|
||||||
|
script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias)
|
||||||
|
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := strings.TrimSpace(string(out))
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", destination)
|
||||||
|
require.NoError(t, err, "Failed to dial destination")
|
||||||
|
defer func() {
|
||||||
|
err := conn.Close()
|
||||||
|
assert.NoError(t, err, "Failed to close connection")
|
||||||
|
}()
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(destination)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
|
||||||
|
|
||||||
|
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||||
|
require.NoError(t, err, "Failed to execute Find-NetRoute")
|
||||||
|
|
||||||
|
var outputs []FindNetRouteOutput
|
||||||
|
err = json.Unmarshal(out, &outputs)
|
||||||
|
require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute")
|
||||||
|
|
||||||
|
require.Greater(t, len(outputs), 0, "No route found for destination")
|
||||||
|
combinedOutput := combineOutputs(outputs)
|
||||||
|
|
||||||
|
return combinedOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
|
||||||
|
require.NoError(t, err)
|
||||||
|
subnetMaskSize, _ := ipNet.Mask.Size()
|
||||||
|
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
|
||||||
|
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
|
||||||
|
|
||||||
|
// Wait for the IP address to be applied
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
err = waitForIPAddress(ctx, interfaceName, ip.String())
|
||||||
|
require.NoError(t, err, "IP address not applied within timeout")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
|
||||||
|
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
|
||||||
|
})
|
||||||
|
|
||||||
|
return interfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchOriginalGateway() (*RouteInfo, error) {
|
||||||
|
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var routeInfo RouteInfo
|
||||||
|
err = json.Unmarshal(output, &routeInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JSON output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &routeInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch")
|
||||||
|
assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch")
|
||||||
|
assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch")
|
||||||
|
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||||
|
for _, ip := range ipAddresses {
|
||||||
|
if strings.TrimSpace(ip) == expectedIPAddress {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
|
||||||
|
var combined FindNetRouteOutput
|
||||||
|
|
||||||
|
for _, output := range outputs {
|
||||||
|
if output.IPAddress != "" {
|
||||||
|
combined.IPAddress = output.IPAddress
|
||||||
|
}
|
||||||
|
if output.InterfaceIndex != 0 {
|
||||||
|
combined.InterfaceIndex = output.InterfaceIndex
|
||||||
|
}
|
||||||
|
if output.InterfaceAlias != "" {
|
||||||
|
combined.InterfaceAlias = output.InterfaceAlias
|
||||||
|
}
|
||||||
|
if output.AddressFamily != 0 {
|
||||||
|
combined.AddressFamily = output.AddressFamily
|
||||||
|
}
|
||||||
|
if output.NextHop != "" {
|
||||||
|
combined.NextHop = output.NextHop
|
||||||
|
}
|
||||||
|
if output.DestinationPrefix != "" {
|
||||||
|
combined.DestinationPrefix = output.DestinationPrefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &combined
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
|
||||||
|
}
|
||||||
82
client/internal/session.go
Normal file
82
client/internal/session.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionWatcher struct {
|
||||||
|
ctx context.Context
|
||||||
|
mutex sync.Mutex
|
||||||
|
|
||||||
|
peerStatusRecorder *peer.Status
|
||||||
|
watchTicker *time.Ticker
|
||||||
|
|
||||||
|
sendNotification bool
|
||||||
|
onExpireListener func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionWatcher creates a new instance of SessionWatcher.
|
||||||
|
func NewSessionWatcher(ctx context.Context, peerStatusRecorder *peer.Status) *SessionWatcher {
|
||||||
|
s := &SessionWatcher{
|
||||||
|
ctx: ctx,
|
||||||
|
peerStatusRecorder: peerStatusRecorder,
|
||||||
|
watchTicker: time.NewTicker(2 * time.Second),
|
||||||
|
}
|
||||||
|
go s.startWatcher()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOnExpireListener sets the callback func to be called when the session expires.
|
||||||
|
func (s *SessionWatcher) SetOnExpireListener(onExpire func()) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
s.onExpireListener = onExpire
|
||||||
|
}
|
||||||
|
|
||||||
|
// startWatcher continuously checks if the session requires login and
|
||||||
|
// calls the onExpireListener if login is required.
|
||||||
|
func (s *SessionWatcher) startWatcher() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
s.watchTicker.Stop()
|
||||||
|
return
|
||||||
|
case <-s.watchTicker.C:
|
||||||
|
managementState := s.peerStatusRecorder.GetManagementState()
|
||||||
|
if managementState.Connected {
|
||||||
|
s.sendNotification = true
|
||||||
|
}
|
||||||
|
|
||||||
|
isLoginRequired := s.peerStatusRecorder.IsLoginRequired()
|
||||||
|
if isLoginRequired && s.sendNotification && s.onExpireListener != nil {
|
||||||
|
s.mutex.Lock()
|
||||||
|
s.onExpireListener()
|
||||||
|
s.sendNotification = false
|
||||||
|
s.mutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUIApp checks whether UI application is running.
|
||||||
|
func CheckUIApp() bool {
|
||||||
|
cmd := exec.Command("ps", "-ef")
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(output), "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
if strings.Contains(line, "netbird-ui") && !strings.Contains(line, "grep") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
24
client/internal/stdnet/dialer.go
Normal file
24
client/internal/stdnet/dialer.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dial connects to the address on the named network.
|
||||||
|
func (n *Net) Dial(network, address string) (net.Conn, error) {
|
||||||
|
return nbnet.NewDialer().Dial(network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialUDP connects to the address on the named UDP network.
|
||||||
|
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
return nbnet.DialUDP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialTCP connects to the address on the named TCP network.
|
||||||
|
func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
|
||||||
|
return nbnet.DialTCP(network, laddr, raddr)
|
||||||
|
}
|
||||||
20
client/internal/stdnet/listener.go
Normal file
20
client/internal/stdnet/listener.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenPacket listens for incoming packets on the given network and address.
|
||||||
|
func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) {
|
||||||
|
return nbnet.NewListener().ListenPacket(context.Background(), network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenUDP acts like ListenPacket for UDP networks.
|
||||||
|
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
return nbnet.ListenUDP(network, locAddr)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user