mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
254 Commits
v0.13.0
...
proxy_cfg_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca10541f50 | ||
|
|
9b9f5fb64b | ||
|
|
3d3de0f1dd | ||
|
|
251f2d7bc2 | ||
|
|
94c6d8c55b | ||
|
|
306e02d32b | ||
|
|
8375491708 | ||
|
|
0f449f64ec | ||
|
|
a7953cef58 | ||
|
|
2bf846b183 | ||
|
|
e197b89ac3 | ||
|
|
6aba28ccb7 | ||
|
|
2a04175238 | ||
|
|
84467c1531 | ||
|
|
44b7afc06b | ||
|
|
dea45633a6 | ||
|
|
14939202f0 | ||
|
|
c5a8cc59a8 | ||
|
|
3d3400ab43 | ||
|
|
cdf5368d20 | ||
|
|
0803f16a49 | ||
|
|
b6aafd8f09 | ||
|
|
4c9fc9850c | ||
|
|
0e9d5807d6 | ||
|
|
b5a3248f9f | ||
|
|
5e1dfb28c0 | ||
|
|
d1fe03a2d4 | ||
|
|
bb453d334b | ||
|
|
56737bab6c | ||
|
|
37d20671a9 | ||
|
|
3f717eb759 | ||
|
|
f9f8cbdcaa | ||
|
|
a3849b978b | ||
|
|
1c071e4981 | ||
|
|
189321f09d | ||
|
|
db5f931373 | ||
|
|
b556736b31 | ||
|
|
01e19a7c67 | ||
|
|
9494cbdf24 | ||
|
|
9c09b13a25 | ||
|
|
8bb999cf2a | ||
|
|
565b8ce1c7 | ||
|
|
e96a975737 | ||
|
|
8585e3ccf7 | ||
|
|
cd002c6400 | ||
|
|
0629697db1 | ||
|
|
25a92a0052 | ||
|
|
9c51d85cb4 | ||
|
|
9865179207 | ||
|
|
56f10085f4 | ||
|
|
a7574907ae | ||
|
|
71e81533bc | ||
|
|
23b92e2615 | ||
|
|
9158a4653a | ||
|
|
ccbf749171 | ||
|
|
dea7e8d4e7 | ||
|
|
a0441e7d04 | ||
|
|
9702946474 | ||
|
|
e262f3536e | ||
|
|
addfed3af0 | ||
|
|
bf723ec66f | ||
|
|
10afc8eeb8 | ||
|
|
0b21e05a52 | ||
|
|
94c646f1e5 | ||
|
|
4f7d34c5c7 | ||
|
|
0455e574b8 | ||
|
|
965ba8837f | ||
|
|
61146a51d0 | ||
|
|
8f9826b207 | ||
|
|
4f8a156cb2 | ||
|
|
ff0b395fc5 | ||
|
|
0aad9169e9 | ||
|
|
237bfde1f2 | ||
|
|
bfff0c36aa | ||
|
|
41458a09e9 | ||
|
|
abd8287da8 | ||
|
|
1057cd211d | ||
|
|
d3c49c71f2 | ||
|
|
49030ab71e | ||
|
|
7548780f8f | ||
|
|
277b65b833 | ||
|
|
071ad2b993 | ||
|
|
32b345991a | ||
|
|
0e8a552334 | ||
|
|
005c4dd44a | ||
|
|
e903522f8c | ||
|
|
ea88ec6d27 | ||
|
|
2be1a82f4a | ||
|
|
367eff493a | ||
|
|
73a5bc33b3 | ||
|
|
87cbff1e7a | ||
|
|
fe1ea4a2d0 | ||
|
|
f14f34cf2b | ||
|
|
24cc5c4ef2 | ||
|
|
109481e26d | ||
|
|
18098e7a7d | ||
|
|
5993982cca | ||
|
|
a42f7d2c3b | ||
|
|
e2a3fc7558 | ||
|
|
5d191a8b9d | ||
|
|
86f9051a30 | ||
|
|
489892553a | ||
|
|
b05e30ac5a | ||
|
|
769388cd21 | ||
|
|
c54fb9643c | ||
|
|
5dc0ff42a5 | ||
|
|
45badd2c39 | ||
|
|
d3de035961 | ||
|
|
79a8109d5e | ||
|
|
b2da0ae70f | ||
|
|
931c20c8fe | ||
|
|
2eaf4aa8d7 | ||
|
|
110067c00f | ||
|
|
32c96c15b8 | ||
|
|
ca1dc5ac88 | ||
|
|
ce775d59ae | ||
|
|
f273fe9f51 | ||
|
|
e08af7fcdf | ||
|
|
454240ca05 | ||
|
|
1343a3f00e | ||
|
|
2a79995706 | ||
|
|
e869882da1 | ||
|
|
6c8bb60632 | ||
|
|
4d7029d80c | ||
|
|
909f305728 | ||
|
|
5e2f66d591 | ||
|
|
ea44c1b723 | ||
|
|
a7519859bc | ||
|
|
9b000b89d5 | ||
|
|
5c1acdbf2f | ||
|
|
db3a9f0aa2 | ||
|
|
ecc4f8a10d | ||
|
|
03abdfa112 | ||
|
|
9746a7f61a | ||
|
|
4ec6d5d20b | ||
|
|
3bab745142 | ||
|
|
0ca3d27a80 | ||
|
|
c5942e6b33 | ||
|
|
726ffb5740 | ||
|
|
430f92094e | ||
|
|
dfb7960cd4 | ||
|
|
ab0cf1b8aa | ||
|
|
8ebd6ce963 | ||
|
|
42ba0765c8 | ||
|
|
514403db37 | ||
|
|
488d338ce8 | ||
|
|
d6c2b46019 | ||
|
|
6a75ec4ab7 | ||
|
|
b66e984ddd | ||
|
|
c65a934107 | ||
|
|
2e7d199a6d | ||
|
|
55ebf93815 | ||
|
|
9e74f30d2f | ||
|
|
8ac7eaf833 | ||
|
|
71d24e59e6 | ||
|
|
992cfe64e1 | ||
|
|
d1703479ff | ||
|
|
a27fe4326c | ||
|
|
e6292e3124 | ||
|
|
628b497e81 | ||
|
|
8f66dea11c | ||
|
|
de8608f99f | ||
|
|
9c5adfea2b | ||
|
|
8e4710763e | ||
|
|
82af60838e | ||
|
|
311b67fe5a | ||
|
|
94d39ab48c | ||
|
|
41a47be379 | ||
|
|
e30def175b | ||
|
|
e1ef091d45 | ||
|
|
511ba6d51f | ||
|
|
b852198f67 | ||
|
|
891ba277b1 | ||
|
|
747797271e | ||
|
|
628a201e31 | ||
|
|
731d3ae464 | ||
|
|
453643683d | ||
|
|
b8cab2882b | ||
|
|
6143b819c5 | ||
|
|
3b42d5e48a | ||
|
|
1d4dfa41d2 | ||
|
|
f8db5742b5 | ||
|
|
bc3cec23ec | ||
|
|
f03aadf064 | ||
|
|
292ee260ad | ||
|
|
2a1efbd0fd | ||
|
|
3bfa26b13b | ||
|
|
221934447e | ||
|
|
9ce8056b17 | ||
|
|
c65a5acab9 | ||
|
|
62de082961 | ||
|
|
c4d9b76634 | ||
|
|
b4bb5c6bb8 | ||
|
|
2b1965c941 | ||
|
|
83e7e30218 | ||
|
|
24310c63e2 | ||
|
|
ed4f90b6aa | ||
|
|
0e9610c5b2 | ||
|
|
ed470d7dbe | ||
|
|
cb8abacadd | ||
|
|
bcac5f7b32 | ||
|
|
95d87384ab | ||
|
|
ea3899e6d6 | ||
|
|
337d3edcc4 | ||
|
|
e914adb5cd | ||
|
|
2f2d45de9e | ||
|
|
b3f339c753 | ||
|
|
e0fc779f58 | ||
|
|
f64e0754ee | ||
|
|
fe22eb3b98 | ||
|
|
69be2a8071 | ||
|
|
1bda8fd563 | ||
|
|
1ab791e91b | ||
|
|
41948f7919 | ||
|
|
60f67076b0 | ||
|
|
c645171c40 | ||
|
|
f832c83a18 | ||
|
|
462a86cfcc | ||
|
|
8a130ec3f1 | ||
|
|
c26cd3b9fe | ||
|
|
9d7b515b26 | ||
|
|
f1f90807e4 | ||
|
|
5bb875a0fa | ||
|
|
9a88ed3cda | ||
|
|
8026c84c95 | ||
|
|
82059df324 | ||
|
|
23610db727 | ||
|
|
f984b8a091 | ||
|
|
4330bfd8ca | ||
|
|
5782496287 | ||
|
|
a0f2b5f591 | ||
|
|
0350faf75d | ||
|
|
9f951c8fb5 | ||
|
|
8276e0908a | ||
|
|
6539b591b6 | ||
|
|
014f1b841f | ||
|
|
b52afe8d42 | ||
|
|
f36869e97d | ||
|
|
78c6231c01 | ||
|
|
e75535d30b | ||
|
|
d8429c5c34 | ||
|
|
c3ed08c249 | ||
|
|
2f0b652dad | ||
|
|
d4214638a0 | ||
|
|
c962d29280 | ||
|
|
44af5be30f | ||
|
|
fe63a64b6e | ||
|
|
d31219ba89 | ||
|
|
756ce96da9 | ||
|
|
b64f5ffcb4 | ||
|
|
eb45310c8f | ||
|
|
d5dfed498b | ||
|
|
3fc89749c1 | ||
|
|
aecee361d0 |
4
.github/workflows/golang-test-darwin.yml
vendored
4
.github/workflows/golang-test-darwin.yml
vendored
@@ -6,6 +6,10 @@ on:
|
|||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
|
|||||||
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -6,6 +6,10 @@ on:
|
|||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
strategy:
|
strategy:
|
||||||
@@ -66,13 +70,13 @@ jobs:
|
|||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
- name: Generate Iface Test bin
|
- name: Generate Iface Test bin
|
||||||
run: go test -c -o iface-testing.bin ./iface/...
|
run: go test -c -o iface-testing.bin ./iface/
|
||||||
|
|
||||||
- name: Generate RouteManager Test bin
|
- name: Generate RouteManager Test bin
|
||||||
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||||
|
|
||||||
- name: Generate Engine Test bin
|
- name: Generate Engine Test bin
|
||||||
run: go test -c -o engine-testing.bin ./client/internal/*.go
|
run: go test -c -o engine-testing.bin ./client/internal
|
||||||
|
|
||||||
- name: Generate Peer Test bin
|
- name: Generate Peer Test bin
|
||||||
run: go test -c -o peer-testing.bin ./client/internal/peer/...
|
run: go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||||
@@ -89,4 +93,4 @@ jobs:
|
|||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Peer tests in docker
|
- name: Run Peer tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||||
|
|||||||
60
.github/workflows/golang-test-windows.yml
vendored
60
.github/workflows/golang-test-windows.yml
vendored
@@ -6,47 +6,45 @@ on:
|
|||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
|
env:
|
||||||
|
downloadPath: '${{ github.workspace }}\temp'
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v2
|
|
||||||
- run: bash -x wireguard_nt.sh
|
|
||||||
working-directory: client
|
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v2
|
|
||||||
with:
|
|
||||||
name: syso
|
|
||||||
path: client/*.syso
|
|
||||||
retention-days: 1
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
needs: pre
|
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v4
|
||||||
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version: 1.19.x
|
go-version: 1.19.x
|
||||||
|
|
||||||
- uses: actions/cache@v2
|
- name: Download wintun
|
||||||
|
uses: carlosperate/download-file-action@v2
|
||||||
|
id: download-wintun
|
||||||
with:
|
with:
|
||||||
path: |
|
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
|
||||||
%LocalAppData%\go-build
|
file-name: wintun.zip
|
||||||
~\go\pkg\mod
|
location: ${{ env.downloadPath }}
|
||||||
~\AppData\Local\go-build
|
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-go-
|
|
||||||
|
|
||||||
- uses: actions/download-artifact@v2
|
- name: Decompressing wintun files
|
||||||
with:
|
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||||
name: syso
|
|
||||||
path: iface\
|
|
||||||
|
|
||||||
- name: Test
|
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||||
run: go test -tags=load_wgnt_from_rsrc -timeout 5m -p 1 ./...
|
|
||||||
|
- run: choco install -y sysinternals
|
||||||
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||||
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||||
|
|
||||||
|
- name: test
|
||||||
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
||||||
|
- name: test output
|
||||||
|
if: ${{ always() }}
|
||||||
|
run: Get-Content test-out.txt
|
||||||
3
.github/workflows/golangci-lint.yml
vendored
3
.github/workflows/golangci-lint.yml
vendored
@@ -1,5 +1,8 @@
|
|||||||
name: golangci-lint
|
name: golangci-lint
|
||||||
on: [pull_request]
|
on: [pull_request]
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
jobs:
|
jobs:
|
||||||
golangci:
|
golangci:
|
||||||
name: lint
|
name: lint
|
||||||
|
|||||||
60
.github/workflows/install-test-darwin.yml
vendored
Normal file
60
.github/workflows/install-test-darwin.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
name: Test installation Darwin
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "release_files/install.sh"
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
jobs:
|
||||||
|
install-cli-only:
|
||||||
|
runs-on: macos-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Rename brew package
|
||||||
|
if: ${{ matrix.check_bin_install }}
|
||||||
|
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||||
|
|
||||||
|
- name: Run install script
|
||||||
|
run: |
|
||||||
|
sh ./release_files/install.sh
|
||||||
|
env:
|
||||||
|
SKIP_UI_APP: true
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
if ! command -v netbird &> /dev/null; then
|
||||||
|
echo "Error: netbird is not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
install-all:
|
||||||
|
runs-on: macos-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Rename brew package
|
||||||
|
if: ${{ matrix.check_bin_install }}
|
||||||
|
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||||
|
|
||||||
|
- name: Run install script
|
||||||
|
run: |
|
||||||
|
sh ./release_files/install.sh
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
if ! command -v netbird &> /dev/null; then
|
||||||
|
echo "Error: netbird is not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
|
||||||
|
echo "Error: NetBird UI is not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
38
.github/workflows/install-test-linux.yml
vendored
Normal file
38
.github/workflows/install-test-linux.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: Test installation Linux
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "release_files/install.sh"
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
jobs:
|
||||||
|
install-cli-only:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
check_bin_install: [true, false]
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Rename apt package
|
||||||
|
if: ${{ matrix.check_bin_install }}
|
||||||
|
run: |
|
||||||
|
sudo mv /usr/bin/apt /usr/bin/apt.bak
|
||||||
|
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
|
||||||
|
|
||||||
|
- name: Run install script
|
||||||
|
run: |
|
||||||
|
sh ./release_files/install.sh
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
if ! command -v netbird &> /dev/null; then
|
||||||
|
echo "Error: netbird is not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
21
.github/workflows/release.yml
vendored
21
.github/workflows/release.yml
vendored
@@ -9,9 +9,13 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.4"
|
SIGN_PIPE_VER: "v0.0.6"
|
||||||
GORELEASER_VER: "v1.14.1"
|
GORELEASER_VER: "v1.14.1"
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -21,10 +25,6 @@ jobs:
|
|||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
|
|
||||||
- name: Generate syso with DLL
|
|
||||||
run: bash -x wireguard_nt.sh
|
|
||||||
working-directory: client
|
|
||||||
-
|
-
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
@@ -59,6 +59,17 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
- name: Install OS build dependencies
|
- name: Install OS build dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||||
|
|
||||||
|
- name: Install rsrc
|
||||||
|
run: go install github.com/akavel/rsrc@v0.10.2
|
||||||
|
- name: Generate windows rsrc amd64
|
||||||
|
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
|
||||||
|
- name: Generate windows rsrc arm64
|
||||||
|
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
|
||||||
|
- name: Generate windows rsrc arm
|
||||||
|
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
|
||||||
|
- name: Generate windows rsrc 386
|
||||||
|
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||||
-
|
-
|
||||||
name: Run GoReleaser
|
name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v2
|
uses: goreleaser/goreleaser-action@v2
|
||||||
|
|||||||
15
.github/workflows/test-docker-compose-linux.yml
vendored
15
.github/workflows/test-docker-compose-linux.yml
vendored
@@ -6,6 +6,10 @@ on:
|
|||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -59,6 +63,11 @@ jobs:
|
|||||||
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
||||||
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
||||||
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
||||||
|
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
||||||
|
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
||||||
|
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
|
||||||
|
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
|
||||||
|
|
||||||
run: |
|
run: |
|
||||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||||
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||||
@@ -68,6 +77,12 @@ jobs:
|
|||||||
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
||||||
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
||||||
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
||||||
|
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||||
|
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||||
|
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||||
|
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||||
|
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||||
|
grep UseIDToken management.json | grep false
|
||||||
|
|
||||||
- name: run docker compose up
|
- name: run docker compose up
|
||||||
working-directory: infrastructure_files
|
working-directory: infrastructure_files
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ builds:
|
|||||||
- goos: windows
|
- goos: windows
|
||||||
goarch: 386
|
goarch: 386
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
@@ -47,7 +47,7 @@ builds:
|
|||||||
- arm64
|
- arm64
|
||||||
- arm
|
- arm
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
- id: netbird-signal
|
- id: netbird-signal
|
||||||
@@ -61,7 +61,7 @@ builds:
|
|||||||
- arm64
|
- arm64
|
||||||
- arm
|
- arm
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ builds:
|
|||||||
goarch:
|
goarch:
|
||||||
- amd64
|
- amd64
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
- id: netbird-ui-windows
|
- id: netbird-ui-windows
|
||||||
@@ -24,7 +24,7 @@ builds:
|
|||||||
goarch:
|
goarch:
|
||||||
- amd64
|
- amd64
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
- -H windowsgui
|
- -H windowsgui
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ builds:
|
|||||||
- hardfloat
|
- hardfloat
|
||||||
- softfloat
|
- softfloat
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>:hatching_chick: New Release! DNS support.</strong>
|
<strong>:hatching_chick: New Release! Peer expiration.</strong>
|
||||||
<a href="https://github.com/netbirdio/netbird/releases">
|
<a href="https://github.com/netbirdio/netbird/releases">
|
||||||
Learn more
|
Learn more
|
||||||
</a>
|
</a>
|
||||||
|
|||||||
129
client/android/client.go
Normal file
129
client/android/client.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionListener export internal Listener for mobile
|
||||||
|
type ConnectionListener interface {
|
||||||
|
peer.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
// TunAdapter export internal TunAdapter for mobile
|
||||||
|
type TunAdapter interface {
|
||||||
|
iface.TunAdapter
|
||||||
|
}
|
||||||
|
|
||||||
|
// IFaceDiscover export internal IFaceDiscover for mobile
|
||||||
|
type IFaceDiscover interface {
|
||||||
|
stdnet.ExternalIFaceDiscover
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client struct manage the life circle of background service
|
||||||
|
type Client struct {
|
||||||
|
cfgFile string
|
||||||
|
tunAdapter iface.TunAdapter
|
||||||
|
iFaceDiscover IFaceDiscover
|
||||||
|
recorder *peer.Status
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
ctxCancelLock *sync.Mutex
|
||||||
|
deviceName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient instantiate a new Client
|
||||||
|
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client {
|
||||||
|
lvl, _ := log.ParseLevel("trace")
|
||||||
|
log.SetLevel(lvl)
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
cfgFile: cfgFile,
|
||||||
|
deviceName: deviceName,
|
||||||
|
tunAdapter: tunAdapter,
|
||||||
|
iFaceDiscover: iFaceDiscover,
|
||||||
|
recorder: peer.NewRecorder(""),
|
||||||
|
ctxCancelLock: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run start the internal client. It is a blocker function
|
||||||
|
func (c *Client) Run(urlOpener URLOpener) error {
|
||||||
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
|
ConfigPath: c.cfgFile,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
//nolint
|
||||||
|
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
c.ctxCancelLock.Lock()
|
||||||
|
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||||
|
defer c.ctxCancel()
|
||||||
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
|
auth := NewAuthWithConfig(ctx, cfg)
|
||||||
|
err = auth.login(urlOpener)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo do not throw error in case of cancelled context
|
||||||
|
ctx = internal.CtxInitState(ctx)
|
||||||
|
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop the internal client and free the resources
|
||||||
|
func (c *Client) Stop() {
|
||||||
|
c.ctxCancelLock.Lock()
|
||||||
|
defer c.ctxCancelLock.Unlock()
|
||||||
|
if c.ctxCancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.ctxCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeersList return with the list of the PeerInfos
|
||||||
|
func (c *Client) PeersList() *PeerInfoArray {
|
||||||
|
|
||||||
|
fullStatus := c.recorder.GetFullStatus()
|
||||||
|
|
||||||
|
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
|
||||||
|
for n, p := range fullStatus.Peers {
|
||||||
|
pi := PeerInfo{
|
||||||
|
p.IP,
|
||||||
|
p.FQDN,
|
||||||
|
p.ConnStatus.String(),
|
||||||
|
p.Direct,
|
||||||
|
}
|
||||||
|
peerInfos[n] = pi
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PeerInfoArray{items: peerInfos}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConnectionListener set the network connection listener
|
||||||
|
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||||
|
c.recorder.SetConnectionListener(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveConnectionListener remove connection listener
|
||||||
|
func (c *Client) RemoveConnectionListener() {
|
||||||
|
c.recorder.RemoveConnectionListener()
|
||||||
|
}
|
||||||
229
client/android/login.go
Normal file
229
client/android/login.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/cmd"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSOListener is async listener for mobile framework
|
||||||
|
type SSOListener interface {
|
||||||
|
OnSuccess(bool)
|
||||||
|
OnError(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrListener is async listener for mobile framework
|
||||||
|
type ErrListener interface {
|
||||||
|
OnSuccess()
|
||||||
|
OnError(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||||
|
// the backend want to show an url for the user
|
||||||
|
type URLOpener interface {
|
||||||
|
Open(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth can register or login new client
|
||||||
|
type Auth struct {
|
||||||
|
ctx context.Context
|
||||||
|
config *internal.Config
|
||||||
|
cfgPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuth instantiate Auth struct and validate the management URL
|
||||||
|
func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||||
|
inputCfg := internal.ConfigInput{
|
||||||
|
ManagementURL: mgmURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.CreateInMemoryConfig(inputCfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Auth{
|
||||||
|
ctx: context.Background(),
|
||||||
|
config: cfg,
|
||||||
|
cfgPath: cfgPath,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthWithConfig instantiate Auth based on existing config
|
||||||
|
func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
|
||||||
|
return &Auth{
|
||||||
|
ctx: ctx,
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||||
|
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||||
|
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||||
|
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||||
|
go func() {
|
||||||
|
sso, err := a.saveConfigIfSSOSupported()
|
||||||
|
if err != nil {
|
||||||
|
listener.OnError(err)
|
||||||
|
} else {
|
||||||
|
listener.OnSuccess(sso)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
|
supportsSSO := true
|
||||||
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
|
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
|
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
supportsSSO = false
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
if !supportsSSO {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = internal.WriteOutConfig(a.cfgPath, a.config)
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
||||||
|
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
|
||||||
|
go func() {
|
||||||
|
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
|
||||||
|
if err != nil {
|
||||||
|
resultListener.OnError(err)
|
||||||
|
} else {
|
||||||
|
resultListener.OnSuccess()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
|
//nolint
|
||||||
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
|
||||||
|
err := a.withBackOff(a.ctx, func() error {
|
||||||
|
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||||
|
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
|
// we got an answer from management, exit backoff earlier
|
||||||
|
return backoff.Permanent(backoffErr)
|
||||||
|
}
|
||||||
|
return backoffErr
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return internal.WriteOutConfig(a.cfgPath, a.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login try register the client on the server
|
||||||
|
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||||
|
go func() {
|
||||||
|
err := a.login(urlOpener)
|
||||||
|
if err != nil {
|
||||||
|
resultListener.OnError(err)
|
||||||
|
} else {
|
||||||
|
resultListener.OnSuccess()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) login(urlOpener URLOpener) error {
|
||||||
|
var needsLogin bool
|
||||||
|
|
||||||
|
// check if we need to generate JWT token
|
||||||
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||||
|
return
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtToken := ""
|
||||||
|
if needsLogin {
|
||||||
|
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
|
}
|
||||||
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.withBackOff(a.ctx, func() error {
|
||||||
|
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||||
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, error) {
|
||||||
|
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||||
|
if err != nil {
|
||||||
|
s, ok := gstatus.FromError(err)
|
||||||
|
if ok && s.Code() == codes.NotFound {
|
||||||
|
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||||
|
"If you are using hosting Netbird see documentation at " +
|
||||||
|
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||||
|
} else if ok && s.Code() == codes.Unimplemented {
|
||||||
|
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||||
|
"please update your servver or use Setup Keys to login", a.config.ManagementURL)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||||
|
|
||||||
|
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting a request device code failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||||
|
|
||||||
|
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||||
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
tokenInfo, err := hostedClient.WaitToken(waitCTX, flowInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
||||||
|
return backoff.RetryNotify(
|
||||||
|
bf,
|
||||||
|
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
||||||
|
func(err error, duration time.Duration) {
|
||||||
|
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
37
client/android/peer_notifier.go
Normal file
37
client/android/peer_notifier.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
|
type PeerInfo struct {
|
||||||
|
IP string
|
||||||
|
FQDN string
|
||||||
|
ConnStatus string // Todo replace to enum
|
||||||
|
Direct bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||||
|
type PeerInfoCollection interface {
|
||||||
|
Add(s string) PeerInfoCollection
|
||||||
|
Get(i int) string
|
||||||
|
Size() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerInfoArray is the implementation of the PeerInfoCollection
|
||||||
|
type PeerInfoArray struct {
|
||||||
|
items []PeerInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new PeerInfo to the collection
|
||||||
|
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get return an element of the collection
|
||||||
|
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
||||||
|
return &array.items[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size return with the size of the collection
|
||||||
|
func (array PeerInfoArray) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
||||||
78
client/android/preferences.go
Normal file
78
client/android/preferences.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Preferences export a subset of the internal config for gomobile
|
||||||
|
type Preferences struct {
|
||||||
|
configInput internal.ConfigInput
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPreferences create new Preferences instance
|
||||||
|
func NewPreferences(configPath string) *Preferences {
|
||||||
|
ci := internal.ConfigInput{
|
||||||
|
ConfigPath: configPath,
|
||||||
|
}
|
||||||
|
return &Preferences{ci}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetManagementURL read url from config file
|
||||||
|
func (p *Preferences) GetManagementURL() (string, error) {
|
||||||
|
if p.configInput.ManagementURL != "" {
|
||||||
|
return p.configInput.ManagementURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cfg.ManagementURL.String(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetManagementURL store the given url and wait for commit
|
||||||
|
func (p *Preferences) SetManagementURL(url string) {
|
||||||
|
p.configInput.ManagementURL = url
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAdminURL read url from config file
|
||||||
|
func (p *Preferences) GetAdminURL() (string, error) {
|
||||||
|
if p.configInput.AdminURL != "" {
|
||||||
|
return p.configInput.AdminURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cfg.AdminURL.String(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAdminURL store the given url and wait for commit
|
||||||
|
func (p *Preferences) SetAdminURL(url string) {
|
||||||
|
p.configInput.AdminURL = url
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPreSharedKey read preshared key from config file
|
||||||
|
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||||
|
if p.configInput.PreSharedKey != nil {
|
||||||
|
return *p.configInput.PreSharedKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cfg.PreSharedKey, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPreSharedKey store the given key and wait for commit
|
||||||
|
func (p *Preferences) SetPreSharedKey(key string) {
|
||||||
|
p.configInput.PreSharedKey = &key
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit write out the changes into config file
|
||||||
|
func (p *Preferences) Commit() error {
|
||||||
|
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||||
|
return err
|
||||||
|
}
|
||||||
120
client/android/preferences_test.go
Normal file
120
client/android/preferences_test.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPreferences_DefaultValues(t *testing.T) {
|
||||||
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
|
p := NewPreferences(cfgFile)
|
||||||
|
defaultVar, err := p.GetAdminURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read default value: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultVar != internal.DefaultAdminURL {
|
||||||
|
t.Errorf("invalid default admin url: %s", defaultVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultVar, err = p.GetManagementURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read default management URL: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultVar != internal.DefaultManagementURL {
|
||||||
|
t.Errorf("invalid default management url: %s", defaultVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
var preSharedKey string
|
||||||
|
preSharedKey, err = p.GetPreSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read default preshared key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if preSharedKey != "" {
|
||||||
|
t.Errorf("invalid preshared key: %s", preSharedKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreferences_ReadUncommitedValues(t *testing.T) {
|
||||||
|
exampleString := "exampleString"
|
||||||
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
|
p := NewPreferences(cfgFile)
|
||||||
|
|
||||||
|
p.SetAdminURL(exampleString)
|
||||||
|
resp, err := p.GetAdminURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read admin url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleString {
|
||||||
|
t.Errorf("unexpected admin url: %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.SetManagementURL(exampleString)
|
||||||
|
resp, err = p.GetManagementURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read managmenet url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleString {
|
||||||
|
t.Errorf("unexpected managemenet url: %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.SetPreSharedKey(exampleString)
|
||||||
|
resp, err = p.GetPreSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read preshared key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleString {
|
||||||
|
t.Errorf("unexpected preshared key: %s", resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreferences_Commit(t *testing.T) {
|
||||||
|
exampleURL := "https://myurl.com:443"
|
||||||
|
examplePresharedKey := "topsecret"
|
||||||
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
|
p := NewPreferences(cfgFile)
|
||||||
|
|
||||||
|
p.SetAdminURL(exampleURL)
|
||||||
|
p.SetManagementURL(exampleURL)
|
||||||
|
p.SetPreSharedKey(examplePresharedKey)
|
||||||
|
|
||||||
|
err := p.Commit()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to save changes: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p = NewPreferences(cfgFile)
|
||||||
|
resp, err := p.GetAdminURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read admin url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleURL {
|
||||||
|
t.Errorf("unexpected admin url: %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = p.GetManagementURL()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read managmenet url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != exampleURL {
|
||||||
|
t.Errorf("unexpected managemenet url: %s", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = p.GetPreSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read preshared key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != examplePresharedKey {
|
||||||
|
t.Errorf("unexpected preshared key: %s", resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,10 +3,11 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@ var loginCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.GetConfig(internal.ConfigInput{
|
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
@@ -134,7 +135,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
jwtToken = tokenInfo.AccessToken
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
err = WithBackOff(func() error {
|
||||||
@@ -152,7 +153,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*internal.TokenInfo, error) {
|
||||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config)
|
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s, ok := gstatus.FromError(err)
|
s, ok := gstatus.FromError(err)
|
||||||
if ok && s.Code() == codes.NotFound {
|
if ok && s.Code() == codes.NotFound {
|
||||||
@@ -171,12 +172,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostedClient := internal.NewHostedDeviceFlow(
|
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||||
providerConfig.ProviderConfig.Audience,
|
|
||||||
providerConfig.ProviderConfig.ClientID,
|
|
||||||
providerConfig.ProviderConfig.TokenEndpoint,
|
|
||||||
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,15 +4,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -57,7 +59,7 @@ var sshCmd = &cobra.Command{
|
|||||||
|
|
||||||
ctx := internal.CtxInitState(cmd.Context())
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
config, err := internal.ReadConfig(internal.ConfigInput{
|
config, err := internal.UpdateConfig(internal.ConfigInput{
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,25 +2,74 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/spf13/cobra"
|
"github.com/netbirdio/netbird/version"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type peerStateDetailOutput struct {
|
||||||
|
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||||
|
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||||
|
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||||
|
Status string `json:"status" yaml:"status"`
|
||||||
|
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||||
|
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||||
|
Direct bool `json:"direct" yaml:"direct"`
|
||||||
|
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type peersStateOutput struct {
|
||||||
|
Total int `json:"total" yaml:"total"`
|
||||||
|
Connected int `json:"connected" yaml:"connected"`
|
||||||
|
Details []peerStateDetailOutput `json:"details" yaml:"details"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type signalStateOutput struct {
|
||||||
|
URL string `json:"url" yaml:"url"`
|
||||||
|
Connected bool `json:"connected" yaml:"connected"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type managementStateOutput struct {
|
||||||
|
URL string `json:"url" yaml:"url"`
|
||||||
|
Connected bool `json:"connected" yaml:"connected"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type iceCandidateType struct {
|
||||||
|
Local string `json:"local" yaml:"local"`
|
||||||
|
Remote string `json:"remote" yaml:"remote"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type statusOutputOverview struct {
|
||||||
|
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||||
|
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||||
|
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||||
|
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||||
|
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||||
|
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||||
|
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||||
|
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||||
|
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
detailFlag bool
|
detailFlag bool
|
||||||
ipv4Flag bool
|
ipv4Flag bool
|
||||||
|
jsonFlag bool
|
||||||
|
yamlFlag bool
|
||||||
ipsFilter []string
|
ipsFilter []string
|
||||||
statusFilter string
|
statusFilter string
|
||||||
ipsFilterMap map[string]struct{}
|
ipsFilterMap map[string]struct{}
|
||||||
@@ -29,67 +78,99 @@ var (
|
|||||||
var statusCmd = &cobra.Command{
|
var statusCmd = &cobra.Command{
|
||||||
Use: "status",
|
Use: "status",
|
||||||
Short: "status of the Netbird Service",
|
Short: "status of the Netbird Service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: statusFunc,
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
|
||||||
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
|
||||||
|
|
||||||
err := parseFilters()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = util.InitLog(logLevel, "console")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed initializing log %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := internal.CtxInitState(context.Background())
|
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
|
||||||
"If the daemon is not running please run: "+
|
|
||||||
"\nnetbird service install \nnetbird service start\n", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
daemonStatus := fmt.Sprintf("Daemon status: %s\n", resp.GetStatus())
|
|
||||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
|
||||||
|
|
||||||
cmd.Printf("%s\n"+
|
|
||||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
|
||||||
" netbird up \n\n"+
|
|
||||||
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
|
||||||
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
|
||||||
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
|
|
||||||
daemonStatus,
|
|
||||||
)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
pbFullStatus := resp.GetFullStatus()
|
|
||||||
fullStatus := fromProtoFullStatus(pbFullStatus)
|
|
||||||
|
|
||||||
cmd.Print(parseFullStatus(fullStatus, detailFlag, daemonStatus, resp.GetDaemonVersion(), ipv4Flag))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
ipsFilterMap = make(map[string]struct{})
|
ipsFilterMap = make(map[string]struct{})
|
||||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information")
|
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||||
|
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||||
|
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||||
|
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
|
||||||
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
|
err := parseFilters()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = util.InitLog(logLevel, "console")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed initializing log %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(context.Background())
|
||||||
|
|
||||||
|
resp, _ := getStatus(ctx, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||||
|
cmd.Printf("Daemon status: %s\n\n"+
|
||||||
|
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||||
|
" netbird up \n\n"+
|
||||||
|
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
||||||
|
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
||||||
|
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
|
||||||
|
resp.GetStatus(),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipv4Flag {
|
||||||
|
cmd.Print(parseInterfaceIP(resp.GetFullStatus().GetLocalPeerState().GetIP()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||||
|
|
||||||
|
statusOutputString := ""
|
||||||
|
switch {
|
||||||
|
case detailFlag:
|
||||||
|
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||||
|
case jsonFlag:
|
||||||
|
statusOutputString, err = parseToJSON(outputInformationHolder)
|
||||||
|
case yamlFlag:
|
||||||
|
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||||
|
default:
|
||||||
|
statusOutputString = parseGeneralSummary(outputInformationHolder, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Print(statusOutputString)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||||
|
"If the daemon is not running please run: "+
|
||||||
|
"\nnetbird service install \nnetbird service start\n", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "disconnected", "connected":
|
case "", "disconnected", "connected":
|
||||||
@@ -109,195 +190,229 @@ func parseFilters() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromProtoFullStatus(pbFullStatus *proto.FullStatus) nbStatus.FullStatus {
|
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||||
var fullStatus nbStatus.FullStatus
|
pbFullStatus := resp.GetFullStatus()
|
||||||
|
|
||||||
managementState := pbFullStatus.GetManagementState()
|
managementState := pbFullStatus.GetManagementState()
|
||||||
fullStatus.ManagementState.URL = managementState.GetURL()
|
managementOverview := managementStateOutput{
|
||||||
fullStatus.ManagementState.Connected = managementState.GetConnected()
|
URL: managementState.GetURL(),
|
||||||
|
Connected: managementState.GetConnected(),
|
||||||
signalState := pbFullStatus.GetSignalState()
|
|
||||||
fullStatus.SignalState.URL = signalState.GetURL()
|
|
||||||
fullStatus.SignalState.Connected = signalState.GetConnected()
|
|
||||||
|
|
||||||
localPeerState := pbFullStatus.GetLocalPeerState()
|
|
||||||
fullStatus.LocalPeerState.IP = localPeerState.GetIP()
|
|
||||||
fullStatus.LocalPeerState.PubKey = localPeerState.GetPubKey()
|
|
||||||
fullStatus.LocalPeerState.KernelInterface = localPeerState.GetKernelInterface()
|
|
||||||
fullStatus.LocalPeerState.FQDN = localPeerState.GetFqdn()
|
|
||||||
|
|
||||||
var peersState []nbStatus.PeerState
|
|
||||||
|
|
||||||
for _, pbPeerState := range pbFullStatus.GetPeers() {
|
|
||||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
|
||||||
peerState := nbStatus.PeerState{
|
|
||||||
IP: pbPeerState.GetIP(),
|
|
||||||
PubKey: pbPeerState.GetPubKey(),
|
|
||||||
ConnStatus: pbPeerState.GetConnStatus(),
|
|
||||||
ConnStatusUpdate: timeLocal,
|
|
||||||
Relayed: pbPeerState.GetRelayed(),
|
|
||||||
Direct: pbPeerState.GetDirect(),
|
|
||||||
LocalIceCandidateType: pbPeerState.GetLocalIceCandidateType(),
|
|
||||||
RemoteIceCandidateType: pbPeerState.GetRemoteIceCandidateType(),
|
|
||||||
FQDN: pbPeerState.GetFqdn(),
|
|
||||||
}
|
|
||||||
peersState = append(peersState, peerState)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fullStatus.Peers = peersState
|
signalState := pbFullStatus.GetSignalState()
|
||||||
|
signalOverview := signalStateOutput{
|
||||||
|
URL: signalState.GetURL(),
|
||||||
|
Connected: signalState.GetConnected(),
|
||||||
|
}
|
||||||
|
|
||||||
return fullStatus
|
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
||||||
|
|
||||||
|
overview := statusOutputOverview{
|
||||||
|
Peers: peersOverview,
|
||||||
|
CliVersion: version.NetbirdVersion(),
|
||||||
|
DaemonVersion: resp.GetDaemonVersion(),
|
||||||
|
ManagementState: managementOverview,
|
||||||
|
SignalState: signalOverview,
|
||||||
|
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||||
|
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||||
|
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||||
|
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return overview
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFullStatus(fullStatus nbStatus.FullStatus, printDetail bool, daemonStatus string, daemonVersion string, flag bool) string {
|
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||||
|
var peersStateDetail []peerStateDetailOutput
|
||||||
|
localICE := ""
|
||||||
|
remoteICE := ""
|
||||||
|
connType := ""
|
||||||
|
peersConnected := 0
|
||||||
|
for _, pbPeerState := range peers {
|
||||||
|
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||||
|
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isPeerConnected {
|
||||||
|
peersConnected = peersConnected + 1
|
||||||
|
|
||||||
interfaceIP := fullStatus.LocalPeerState.IP
|
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||||
|
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||||
|
connType = "P2P"
|
||||||
|
if pbPeerState.Relayed {
|
||||||
|
connType = "Relayed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||||
|
peerState := peerStateDetailOutput{
|
||||||
|
IP: pbPeerState.GetIP(),
|
||||||
|
PubKey: pbPeerState.GetPubKey(),
|
||||||
|
Status: pbPeerState.GetConnStatus(),
|
||||||
|
LastStatusUpdate: timeLocal.UTC(),
|
||||||
|
ConnType: connType,
|
||||||
|
Direct: pbPeerState.GetDirect(),
|
||||||
|
IceCandidateType: iceCandidateType{
|
||||||
|
Local: localICE,
|
||||||
|
Remote: remoteICE,
|
||||||
|
},
|
||||||
|
FQDN: pbPeerState.GetFqdn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
peersStateDetail = append(peersStateDetail, peerState)
|
||||||
|
}
|
||||||
|
|
||||||
|
sortPeersByIP(peersStateDetail)
|
||||||
|
|
||||||
|
peersOverview := peersStateOutput{
|
||||||
|
Total: len(peersStateDetail),
|
||||||
|
Connected: peersConnected,
|
||||||
|
Details: peersStateDetail,
|
||||||
|
}
|
||||||
|
return peersOverview
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
|
||||||
|
if len(peersStateDetail) > 0 {
|
||||||
|
sort.SliceStable(peersStateDetail, func(i, j int) bool {
|
||||||
|
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
|
||||||
|
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
|
||||||
|
return iAddr.Compare(jAddr) == -1
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInterfaceIP(interfaceIP string) string {
|
||||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
return fmt.Sprintf("%s\n", ip)
|
||||||
|
}
|
||||||
|
|
||||||
if ipv4Flag {
|
func parseToJSON(overview statusOutputOverview) (string, error) {
|
||||||
return fmt.Sprintf("%s\n", ip)
|
jsonBytes, err := json.Marshal(overview)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("json marshal failed")
|
||||||
}
|
}
|
||||||
|
return string(jsonBytes), err
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
func parseToYAML(overview statusOutputOverview) (string, error) {
|
||||||
managementStatusURL = ""
|
yamlBytes, err := yaml.Marshal(overview)
|
||||||
signalStatusURL = ""
|
if err != nil {
|
||||||
managementConnString = "Disconnected"
|
return "", fmt.Errorf("yaml marshal failed")
|
||||||
signalConnString = "Disconnected"
|
|
||||||
interfaceTypeString = "Userspace"
|
|
||||||
)
|
|
||||||
|
|
||||||
if printDetail {
|
|
||||||
managementStatusURL = fmt.Sprintf(" to %s", fullStatus.ManagementState.URL)
|
|
||||||
signalStatusURL = fmt.Sprintf(" to %s", fullStatus.SignalState.URL)
|
|
||||||
}
|
}
|
||||||
|
return string(yamlBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
if fullStatus.ManagementState.Connected {
|
func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
||||||
|
|
||||||
|
managementConnString := "Disconnected"
|
||||||
|
if overview.ManagementState.Connected {
|
||||||
managementConnString = "Connected"
|
managementConnString = "Connected"
|
||||||
|
if showURL {
|
||||||
|
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if fullStatus.SignalState.Connected {
|
signalConnString := "Disconnected"
|
||||||
|
if overview.SignalState.Connected {
|
||||||
signalConnString = "Connected"
|
signalConnString = "Connected"
|
||||||
|
if showURL {
|
||||||
|
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if fullStatus.LocalPeerState.KernelInterface {
|
interfaceTypeString := "Userspace"
|
||||||
|
interfaceIP := overview.IP
|
||||||
|
if overview.KernelInterface {
|
||||||
interfaceTypeString = "Kernel"
|
interfaceTypeString = "Kernel"
|
||||||
} else if fullStatus.LocalPeerState.IP == "" {
|
} else if overview.IP == "" {
|
||||||
interfaceTypeString = "N/A"
|
interfaceTypeString = "N/A"
|
||||||
interfaceIP = "N/A"
|
interfaceIP = "N/A"
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedPeersString, peersConnected := parsePeers(fullStatus.Peers, printDetail)
|
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", peersConnected, len(fullStatus.Peers))
|
|
||||||
|
|
||||||
summary := fmt.Sprintf(
|
summary := fmt.Sprintf(
|
||||||
"Daemon version: %s\n"+
|
"Daemon version: %s\n"+
|
||||||
"CLI version: %s\n"+
|
"CLI version: %s\n"+
|
||||||
"%s"+ // daemon status
|
"Management: %s\n"+
|
||||||
"Management: %s%s\n"+
|
"Signal: %s\n"+
|
||||||
"Signal: %s%s\n"+
|
"FQDN: %s\n"+
|
||||||
"Domain: %s\n"+
|
|
||||||
"NetBird IP: %s\n"+
|
"NetBird IP: %s\n"+
|
||||||
"Interface type: %s\n"+
|
"Interface type: %s\n"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
daemonVersion,
|
overview.DaemonVersion,
|
||||||
system.NetbirdVersion(),
|
version.NetbirdVersion(),
|
||||||
daemonStatus,
|
|
||||||
managementConnString,
|
managementConnString,
|
||||||
managementStatusURL,
|
|
||||||
signalConnString,
|
signalConnString,
|
||||||
signalStatusURL,
|
overview.FQDN,
|
||||||
fullStatus.LocalPeerState.FQDN,
|
|
||||||
interfaceIP,
|
interfaceIP,
|
||||||
interfaceTypeString,
|
interfaceTypeString,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
)
|
)
|
||||||
|
|
||||||
if printDetail {
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"Peers detail:"+
|
|
||||||
"%s\n"+
|
|
||||||
"%s",
|
|
||||||
parsedPeersString,
|
|
||||||
summary,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
func parsePeers(peers []nbStatus.PeerState, printDetail bool) (string, int) {
|
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||||
var (
|
parsedPeersString := parsePeers(overview.Peers)
|
||||||
peersString = ""
|
summary := parseGeneralSummary(overview, true)
|
||||||
peersConnected = 0
|
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"Peers detail:"+
|
||||||
|
"%s\n"+
|
||||||
|
"%s",
|
||||||
|
parsedPeersString,
|
||||||
|
summary,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(peers) > 0 {
|
|
||||||
sort.SliceStable(peers, func(i, j int) bool {
|
|
||||||
iAddr, _ := netip.ParseAddr(peers[i].IP)
|
|
||||||
jAddr, _ := netip.ParseAddr(peers[j].IP)
|
|
||||||
return iAddr.Compare(jAddr) == -1
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
connectedStatusString := peer.StatusConnected.String()
|
|
||||||
|
|
||||||
for _, peerState := range peers {
|
|
||||||
peerConnectionStatus := false
|
|
||||||
if peerState.ConnStatus == connectedStatusString {
|
|
||||||
peersConnected = peersConnected + 1
|
|
||||||
peerConnectionStatus = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if printDetail {
|
|
||||||
|
|
||||||
if skipDetailByFilters(peerState, peerConnectionStatus) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
localICE := "-"
|
|
||||||
remoteICE := "-"
|
|
||||||
connType := "-"
|
|
||||||
|
|
||||||
if peerConnectionStatus {
|
|
||||||
localICE = peerState.LocalIceCandidateType
|
|
||||||
remoteICE = peerState.RemoteIceCandidateType
|
|
||||||
connType = "P2P"
|
|
||||||
if peerState.Relayed {
|
|
||||||
connType = "Relayed"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peerString := fmt.Sprintf(
|
|
||||||
"\n %s:\n"+
|
|
||||||
" NetBird IP: %s\n"+
|
|
||||||
" Public key: %s\n"+
|
|
||||||
" Status: %s\n"+
|
|
||||||
" -- detail --\n"+
|
|
||||||
" Connection type: %s\n"+
|
|
||||||
" Direct: %t\n"+
|
|
||||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
|
||||||
" Last connection update: %s\n",
|
|
||||||
peerState.FQDN,
|
|
||||||
peerState.IP,
|
|
||||||
peerState.PubKey,
|
|
||||||
peerState.ConnStatus,
|
|
||||||
connType,
|
|
||||||
peerState.Direct,
|
|
||||||
localICE,
|
|
||||||
remoteICE,
|
|
||||||
peerState.ConnStatusUpdate.Format("2006-01-02 15:04:05"),
|
|
||||||
)
|
|
||||||
|
|
||||||
peersString = peersString + peerString
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return peersString, peersConnected
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func skipDetailByFilters(peerState nbStatus.PeerState, isConnected bool) bool {
|
func parsePeers(peers peersStateOutput) string {
|
||||||
|
var (
|
||||||
|
peersString = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, peerState := range peers.Details {
|
||||||
|
|
||||||
|
localICE := "-"
|
||||||
|
if peerState.IceCandidateType.Local != "" {
|
||||||
|
localICE = peerState.IceCandidateType.Local
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteICE := "-"
|
||||||
|
if peerState.IceCandidateType.Remote != "" {
|
||||||
|
remoteICE = peerState.IceCandidateType.Remote
|
||||||
|
}
|
||||||
|
|
||||||
|
peerString := fmt.Sprintf(
|
||||||
|
"\n %s:\n"+
|
||||||
|
" NetBird IP: %s\n"+
|
||||||
|
" Public key: %s\n"+
|
||||||
|
" Status: %s\n"+
|
||||||
|
" -- detail --\n"+
|
||||||
|
" Connection type: %s\n"+
|
||||||
|
" Direct: %t\n"+
|
||||||
|
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||||
|
" Last connection update: %s\n",
|
||||||
|
peerState.FQDN,
|
||||||
|
peerState.IP,
|
||||||
|
peerState.PubKey,
|
||||||
|
peerState.Status,
|
||||||
|
peerState.ConnType,
|
||||||
|
peerState.Direct,
|
||||||
|
localICE,
|
||||||
|
remoteICE,
|
||||||
|
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
|
||||||
|
)
|
||||||
|
|
||||||
|
peersString = peersString + peerString
|
||||||
|
}
|
||||||
|
return peersString
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||||
statusEval := false
|
statusEval := false
|
||||||
ipEval := false
|
ipEval := false
|
||||||
|
|
||||||
|
|||||||
301
client/cmd/status_test.go
Normal file
301
client/cmd/status_test.go
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
var resp = &proto.StatusResponse{
|
||||||
|
Status: "Connected",
|
||||||
|
FullStatus: &proto.FullStatus{
|
||||||
|
Peers: []*proto.PeerState{
|
||||||
|
{
|
||||||
|
IP: "192.168.178.101",
|
||||||
|
PubKey: "Pubkey1",
|
||||||
|
Fqdn: "peer-1.awesome-domain.com",
|
||||||
|
ConnStatus: "Connected",
|
||||||
|
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||||
|
Relayed: false,
|
||||||
|
Direct: true,
|
||||||
|
LocalIceCandidateType: "",
|
||||||
|
RemoteIceCandidateType: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: "192.168.178.102",
|
||||||
|
PubKey: "Pubkey2",
|
||||||
|
Fqdn: "peer-2.awesome-domain.com",
|
||||||
|
ConnStatus: "Connected",
|
||||||
|
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||||
|
Relayed: true,
|
||||||
|
Direct: false,
|
||||||
|
LocalIceCandidateType: "relay",
|
||||||
|
RemoteIceCandidateType: "prflx",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ManagementState: &proto.ManagementState{
|
||||||
|
URL: "my-awesome-management.com:443",
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
SignalState: &proto.SignalState{
|
||||||
|
URL: "my-awesome-signal.com:443",
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
LocalPeerState: &proto.LocalPeerState{
|
||||||
|
IP: "192.168.178.100/16",
|
||||||
|
PubKey: "Some-Pub-Key",
|
||||||
|
KernelInterface: true,
|
||||||
|
Fqdn: "some-localhost.awesome-domain.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DaemonVersion: "0.14.1",
|
||||||
|
}
|
||||||
|
|
||||||
|
var overview = statusOutputOverview{
|
||||||
|
Peers: peersStateOutput{
|
||||||
|
Total: 2,
|
||||||
|
Connected: 2,
|
||||||
|
Details: []peerStateDetailOutput{
|
||||||
|
{
|
||||||
|
IP: "192.168.178.101",
|
||||||
|
PubKey: "Pubkey1",
|
||||||
|
FQDN: "peer-1.awesome-domain.com",
|
||||||
|
Status: "Connected",
|
||||||
|
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
||||||
|
ConnType: "P2P",
|
||||||
|
Direct: true,
|
||||||
|
IceCandidateType: iceCandidateType{
|
||||||
|
Local: "",
|
||||||
|
Remote: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: "192.168.178.102",
|
||||||
|
PubKey: "Pubkey2",
|
||||||
|
FQDN: "peer-2.awesome-domain.com",
|
||||||
|
Status: "Connected",
|
||||||
|
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
||||||
|
ConnType: "Relayed",
|
||||||
|
Direct: false,
|
||||||
|
IceCandidateType: iceCandidateType{
|
||||||
|
Local: "relay",
|
||||||
|
Remote: "prflx",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CliVersion: version.NetbirdVersion(),
|
||||||
|
DaemonVersion: "0.14.1",
|
||||||
|
ManagementState: managementStateOutput{
|
||||||
|
URL: "my-awesome-management.com:443",
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
SignalState: signalStateOutput{
|
||||||
|
URL: "my-awesome-signal.com:443",
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
IP: "192.168.178.100/16",
|
||||||
|
PubKey: "Some-Pub-Key",
|
||||||
|
KernelInterface: true,
|
||||||
|
FQDN: "some-localhost.awesome-domain.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||||
|
convertedResult := convertToStatusOutputOverview(resp)
|
||||||
|
|
||||||
|
assert.Equal(t, overview, convertedResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortingOfPeers(t *testing.T) {
|
||||||
|
peers := []peerStateDetailOutput{
|
||||||
|
{
|
||||||
|
IP: "192.168.178.104",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: "192.168.178.102",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: "192.168.178.101",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: "192.168.178.105",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: "192.168.178.103",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sortPeersByIP(peers)
|
||||||
|
|
||||||
|
assert.Equal(t, peers[3].IP, "192.168.178.104")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsingToJSON(t *testing.T) {
|
||||||
|
json, _ := parseToJSON(overview)
|
||||||
|
|
||||||
|
//@formatter:off
|
||||||
|
expectedJSON := "{\"" +
|
||||||
|
"peers\":" +
|
||||||
|
"{" +
|
||||||
|
"\"total\":2," +
|
||||||
|
"\"connected\":2," +
|
||||||
|
"\"details\":" +
|
||||||
|
"[" +
|
||||||
|
"{" +
|
||||||
|
"\"fqdn\":\"peer-1.awesome-domain.com\"," +
|
||||||
|
"\"netbirdIp\":\"192.168.178.101\"," +
|
||||||
|
"\"publicKey\":\"Pubkey1\"," +
|
||||||
|
"\"status\":\"Connected\"," +
|
||||||
|
"\"lastStatusUpdate\":\"2001-01-01T01:01:01Z\"," +
|
||||||
|
"\"connectionType\":\"P2P\"," +
|
||||||
|
"\"direct\":true," +
|
||||||
|
"\"iceCandidateType\":" +
|
||||||
|
"{" +
|
||||||
|
"\"local\":\"\"," +
|
||||||
|
"\"remote\":\"\"" +
|
||||||
|
"}" +
|
||||||
|
"}," +
|
||||||
|
"{" +
|
||||||
|
"\"fqdn\":\"peer-2.awesome-domain.com\"," +
|
||||||
|
"\"netbirdIp\":\"192.168.178.102\"," +
|
||||||
|
"\"publicKey\":\"Pubkey2\"," +
|
||||||
|
"\"status\":\"Connected\"," +
|
||||||
|
"\"lastStatusUpdate\":\"2002-02-02T02:02:02Z\"," +
|
||||||
|
"\"connectionType\":\"Relayed\"," +
|
||||||
|
"\"direct\":false," +
|
||||||
|
"\"iceCandidateType\":" +
|
||||||
|
"{" +
|
||||||
|
"\"local\":\"relay\"," +
|
||||||
|
"\"remote\":\"prflx\"" +
|
||||||
|
"}" +
|
||||||
|
"}" +
|
||||||
|
"]" +
|
||||||
|
"}," +
|
||||||
|
"\"cliVersion\":\"development\"," +
|
||||||
|
"\"daemonVersion\":\"0.14.1\"," +
|
||||||
|
"\"management\":" +
|
||||||
|
"{" +
|
||||||
|
"\"url\":\"my-awesome-management.com:443\"," +
|
||||||
|
"\"connected\":true" +
|
||||||
|
"}," +
|
||||||
|
"\"signal\":" +
|
||||||
|
"{\"" +
|
||||||
|
"url\":\"my-awesome-signal.com:443\"," +
|
||||||
|
"\"connected\":true" +
|
||||||
|
"}," +
|
||||||
|
"\"netbirdIp\":\"192.168.178.100/16\"," +
|
||||||
|
"\"publicKey\":\"Some-Pub-Key\"," +
|
||||||
|
"\"usesKernelInterface\":true," +
|
||||||
|
"\"fqdn\":\"some-localhost.awesome-domain.com\"" +
|
||||||
|
"}"
|
||||||
|
// @formatter:on
|
||||||
|
|
||||||
|
assert.Equal(t, expectedJSON, json)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsingToYAML(t *testing.T) {
|
||||||
|
yaml, _ := parseToYAML(overview)
|
||||||
|
|
||||||
|
expectedYAML := "peers:\n" +
|
||||||
|
" total: 2\n" +
|
||||||
|
" connected: 2\n" +
|
||||||
|
" details:\n" +
|
||||||
|
" - fqdn: peer-1.awesome-domain.com\n" +
|
||||||
|
" netbirdIp: 192.168.178.101\n" +
|
||||||
|
" publicKey: Pubkey1\n" +
|
||||||
|
" status: Connected\n" +
|
||||||
|
" lastStatusUpdate: 2001-01-01T01:01:01Z\n" +
|
||||||
|
" connectionType: P2P\n" +
|
||||||
|
" direct: true\n" +
|
||||||
|
" iceCandidateType:\n" +
|
||||||
|
" local: \"\"\n" +
|
||||||
|
" remote: \"\"\n" +
|
||||||
|
" - fqdn: peer-2.awesome-domain.com\n" +
|
||||||
|
" netbirdIp: 192.168.178.102\n" +
|
||||||
|
" publicKey: Pubkey2\n" +
|
||||||
|
" status: Connected\n" +
|
||||||
|
" lastStatusUpdate: 2002-02-02T02:02:02Z\n" +
|
||||||
|
" connectionType: Relayed\n" +
|
||||||
|
" direct: false\n" +
|
||||||
|
" iceCandidateType:\n" +
|
||||||
|
" local: relay\n" +
|
||||||
|
" remote: prflx\n" +
|
||||||
|
"cliVersion: development\n" +
|
||||||
|
"daemonVersion: 0.14.1\n" +
|
||||||
|
"management:\n" +
|
||||||
|
" url: my-awesome-management.com:443\n" +
|
||||||
|
" connected: true\n" +
|
||||||
|
"signal:\n" +
|
||||||
|
" url: my-awesome-signal.com:443\n" +
|
||||||
|
" connected: true\n" +
|
||||||
|
"netbirdIp: 192.168.178.100/16\n" +
|
||||||
|
"publicKey: Some-Pub-Key\n" +
|
||||||
|
"usesKernelInterface: true\n" +
|
||||||
|
"fqdn: some-localhost.awesome-domain.com\n"
|
||||||
|
|
||||||
|
assert.Equal(t, expectedYAML, yaml)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsingToDetail(t *testing.T) {
|
||||||
|
detail := parseToFullDetailSummary(overview)
|
||||||
|
|
||||||
|
expectedDetail := "Peers detail:\n" +
|
||||||
|
" peer-1.awesome-domain.com:\n" +
|
||||||
|
" NetBird IP: 192.168.178.101\n" +
|
||||||
|
" Public key: Pubkey1\n" +
|
||||||
|
" Status: Connected\n" +
|
||||||
|
" -- detail --\n" +
|
||||||
|
" Connection type: P2P\n" +
|
||||||
|
" Direct: true\n" +
|
||||||
|
" ICE candidate (Local/Remote): -/-\n" +
|
||||||
|
" Last connection update: 2001-01-01 01:01:01\n" +
|
||||||
|
"\n" +
|
||||||
|
" peer-2.awesome-domain.com:\n" +
|
||||||
|
" NetBird IP: 192.168.178.102\n" +
|
||||||
|
" Public key: Pubkey2\n" +
|
||||||
|
" Status: Connected\n" +
|
||||||
|
" -- detail --\n" +
|
||||||
|
" Connection type: Relayed\n" +
|
||||||
|
" Direct: false\n" +
|
||||||
|
" ICE candidate (Local/Remote): relay/prflx\n" +
|
||||||
|
" Last connection update: 2002-02-02 02:02:02\n" +
|
||||||
|
"\n" +
|
||||||
|
"Daemon version: 0.14.1\n" +
|
||||||
|
"CLI version: development\n" +
|
||||||
|
"Management: Connected to my-awesome-management.com:443\n" +
|
||||||
|
"Signal: Connected to my-awesome-signal.com:443\n" +
|
||||||
|
"FQDN: some-localhost.awesome-domain.com\n" +
|
||||||
|
"NetBird IP: 192.168.178.100/16\n" +
|
||||||
|
"Interface type: Kernel\n" +
|
||||||
|
"Peers count: 2/2 Connected\n"
|
||||||
|
|
||||||
|
assert.Equal(t, expectedDetail, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsingToShortVersion(t *testing.T) {
|
||||||
|
shortVersion := parseGeneralSummary(overview, false)
|
||||||
|
|
||||||
|
expectedString := "Daemon version: 0.14.1\n" +
|
||||||
|
"CLI version: development\n" +
|
||||||
|
"Management: Connected\n" +
|
||||||
|
"Signal: Connected\n" +
|
||||||
|
"FQDN: some-localhost.awesome-domain.com\n" +
|
||||||
|
"NetBird IP: 192.168.178.100/16\n" +
|
||||||
|
"Interface type: Kernel\n" +
|
||||||
|
"Peers count: 2/2 Connected\n"
|
||||||
|
|
||||||
|
assert.Equal(t, expectedString, shortVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsingOfIP(t *testing.T) {
|
||||||
|
InterfaceIP := "192.168.178.123/16"
|
||||||
|
|
||||||
|
parsedIP := parseInterfaceIP(InterfaceIP)
|
||||||
|
|
||||||
|
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||||
|
}
|
||||||
@@ -3,17 +3,19 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"net"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"net/netip"
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
"strings"
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"strings"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -70,7 +72,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.GetConfig(internal.ConfigInput{
|
config, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
@@ -92,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, nbStatus.NewRecorder())
|
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -11,7 +12,7 @@ var (
|
|||||||
Short: "prints Netbird version",
|
Short: "prints Netbird version",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
cmd.Println(system.NetbirdVersion())
|
cmd.Println(version.NetbirdVersion())
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
57
client/firewall/firewall.go
Normal file
57
client/firewall/firewall.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rule abstraction should be implemented by each firewall manager
|
||||||
|
//
|
||||||
|
// Each firewall type for different OS can use different type
|
||||||
|
// of the properties to hold data of the created rule
|
||||||
|
type Rule interface {
|
||||||
|
// GetRuleID returns the rule id
|
||||||
|
GetRuleID() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direction is the direction of the traffic
|
||||||
|
type Direction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DirectionSrc is the direction of the traffic from the source
|
||||||
|
DirectionSrc Direction = iota
|
||||||
|
// DirectionDst is the direction of the traffic from the destination
|
||||||
|
DirectionDst
|
||||||
|
)
|
||||||
|
|
||||||
|
// Action is the action to be taken on a rule
|
||||||
|
type Action int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ActionAccept is the action to accept a packet
|
||||||
|
ActionAccept Action = iota
|
||||||
|
// ActionDrop is the action to drop a packet
|
||||||
|
ActionDrop
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager is the high level abstraction of a firewall manager
|
||||||
|
//
|
||||||
|
// It declares methods which handle actions required by the
|
||||||
|
// Netbird client for ACL and routing functionality
|
||||||
|
type Manager interface {
|
||||||
|
// AddFiltering rule to the firewall
|
||||||
|
AddFiltering(
|
||||||
|
ip net.IP,
|
||||||
|
port *Port,
|
||||||
|
direction Direction,
|
||||||
|
action Action,
|
||||||
|
comment string,
|
||||||
|
) (Rule, error)
|
||||||
|
|
||||||
|
// DeleteRule from the firewall by rule definition
|
||||||
|
DeleteRule(rule Rule) error
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
Reset() error
|
||||||
|
|
||||||
|
// TODO: migrate routemanager firewal actions to this interface
|
||||||
|
}
|
||||||
160
client/firewall/iptables/manager_linux.go
Normal file
160
client/firewall/iptables/manager_linux.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ChainFilterName is the name of the chain that is used for filtering by the Netbird client
|
||||||
|
ChainFilterName = "NETBIRD-ACL"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager of iptables firewall
|
||||||
|
type Manager struct {
|
||||||
|
mutex sync.Mutex
|
||||||
|
|
||||||
|
ipv4Client *iptables.IPTables
|
||||||
|
ipv6Client *iptables.IPTables
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create iptables firewall manager
|
||||||
|
func Create() (*Manager, error) {
|
||||||
|
m := &Manager{}
|
||||||
|
|
||||||
|
// init clients for booth ipv4 and ipv6
|
||||||
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||||
|
}
|
||||||
|
m.ipv4Client = ipv4Client
|
||||||
|
|
||||||
|
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ip6tables is not installed in the system or not supported")
|
||||||
|
}
|
||||||
|
m.ipv6Client = ipv6Client
|
||||||
|
|
||||||
|
if err := m.Reset(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to reset firewall: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFiltering rule to the firewall
|
||||||
|
func (m *Manager) AddFiltering(
|
||||||
|
ip net.IP,
|
||||||
|
port *fw.Port,
|
||||||
|
direction fw.Direction,
|
||||||
|
action fw.Action,
|
||||||
|
comment string,
|
||||||
|
) (fw.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
client := m.client(ip)
|
||||||
|
ok, err := client.ChainExists("filter", ChainFilterName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if chain exists: %s", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
if err := client.NewChain("filter", ChainFilterName); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create chain: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) {
|
||||||
|
return nil, fmt.Errorf("invalid port definition")
|
||||||
|
}
|
||||||
|
pv := strconv.Itoa(port.Values[0])
|
||||||
|
if port.IsRange {
|
||||||
|
pv += ":" + strconv.Itoa(port.Values[1])
|
||||||
|
}
|
||||||
|
specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment)
|
||||||
|
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rule := &Rule{
|
||||||
|
id: uuid.New().String(),
|
||||||
|
specs: specs,
|
||||||
|
v6: ip.To4() == nil,
|
||||||
|
}
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRule from the firewall by rule definition
|
||||||
|
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid rule type")
|
||||||
|
}
|
||||||
|
client := m.ipv4Client
|
||||||
|
if r.v6 {
|
||||||
|
client = m.ipv6Client
|
||||||
|
}
|
||||||
|
return client.Delete("filter", ChainFilterName, r.specs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
func (m *Manager) Reset() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
if err := m.reset(m.ipv4Client, "filter", ChainFilterName); err != nil {
|
||||||
|
return fmt.Errorf("clean ipv4 firewall ACL chain: %w", err)
|
||||||
|
}
|
||||||
|
if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil {
|
||||||
|
return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset firewall chain, clear it and drop it
|
||||||
|
func (m *Manager) reset(client *iptables.IPTables, table, chain string) error {
|
||||||
|
ok, err := client.ChainExists(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check if chain exists: %w", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := client.ClearChain(table, ChainFilterName); err != nil {
|
||||||
|
return fmt.Errorf("failed to clear chain: %w", err)
|
||||||
|
}
|
||||||
|
return client.DeleteChain(table, ChainFilterName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
|
func (m *Manager) filterRuleSpecs(
|
||||||
|
table string, chain string, ip net.IP, port string,
|
||||||
|
direction fw.Direction, action fw.Action, comment string,
|
||||||
|
) (specs []string) {
|
||||||
|
if direction == fw.DirectionSrc {
|
||||||
|
specs = append(specs, "-s", ip.String())
|
||||||
|
}
|
||||||
|
specs = append(specs, "-p", "tcp", "--dport", port)
|
||||||
|
specs = append(specs, "-j", m.actionToStr(action))
|
||||||
|
return append(specs, "-m", "comment", "--comment", comment)
|
||||||
|
}
|
||||||
|
|
||||||
|
// client returns corresponding iptables client for the given ip
|
||||||
|
func (m *Manager) client(ip net.IP) *iptables.IPTables {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
return m.ipv4Client
|
||||||
|
}
|
||||||
|
return m.ipv6Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) actionToStr(action fw.Action) string {
|
||||||
|
if action == fw.ActionAccept {
|
||||||
|
return "ACCEPT"
|
||||||
|
}
|
||||||
|
return "DROP"
|
||||||
|
}
|
||||||
105
client/firewall/iptables/manager_linux_test.go
Normal file
105
client/firewall/iptables/manager_linux_test.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewManager(t *testing.T) {
|
||||||
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rule1 fw.Rule
|
||||||
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.20.0.2")
|
||||||
|
port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}}
|
||||||
|
rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, true, rule1.(*Rule).specs...)
|
||||||
|
})
|
||||||
|
|
||||||
|
var rule2 fw.Rule
|
||||||
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.20.0.3")
|
||||||
|
port := &fw.Port{
|
||||||
|
Proto: fw.PortProtocolTCP,
|
||||||
|
Values: []int{8043: 8046},
|
||||||
|
}
|
||||||
|
rule2, err = manager.AddFiltering(
|
||||||
|
ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, true, rule2.(*Rule).specs...)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
|
if err := manager.DeleteRule(rule1); err != nil {
|
||||||
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, false, rule1.(*Rule).specs...)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
|
if err := manager.DeleteRule(rule2); err != nil {
|
||||||
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkRuleSpecs(t, ipv4Client, false, rule2.(*Rule).specs...)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reset check", func(t *testing.T) {
|
||||||
|
// add second rule
|
||||||
|
ip := net.ParseIP("10.20.0.3")
|
||||||
|
port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}}
|
||||||
|
_, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to add rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := manager.Reset(); err != nil {
|
||||||
|
t.Errorf("failed to reset: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ipv4Client.ChainExists("filter", ChainFilterName)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to drop chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
t.Errorf("chain '%v' still exists after Reset", ChainFilterName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, mustExists bool, rulespec ...string) {
|
||||||
|
exists, err := ipv4Client.Exists("filter", ChainFilterName, rulespec...)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to check rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists && mustExists {
|
||||||
|
t.Errorf("rule '%v' does not exist", rulespec)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if exists && !mustExists {
|
||||||
|
t.Errorf("rule '%v' exist", rulespec)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
13
client/firewall/iptables/rule.go
Normal file
13
client/firewall/iptables/rule.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
// Rule to handle management of rules
|
||||||
|
type Rule struct {
|
||||||
|
id string
|
||||||
|
specs []string
|
||||||
|
v6 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRuleID returns the rule id
|
||||||
|
func (r *Rule) GetRuleID() string {
|
||||||
|
return r.id
|
||||||
|
}
|
||||||
24
client/firewall/port.go
Normal file
24
client/firewall/port.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
// PortProtocol is the protocol of the port
|
||||||
|
type PortProtocol string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PortProtocolTCP is the TCP protocol
|
||||||
|
PortProtocolTCP PortProtocol = "tcp"
|
||||||
|
|
||||||
|
// PortProtocolUDP is the UDP protocol
|
||||||
|
PortProtocolUDP PortProtocol = "udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Port of the address for firewall rule
|
||||||
|
type Port struct {
|
||||||
|
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||||
|
IsRange bool
|
||||||
|
|
||||||
|
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||||
|
Values []int
|
||||||
|
|
||||||
|
// Proto is the protocol of the port
|
||||||
|
Proto PortProtocol
|
||||||
|
}
|
||||||
@@ -193,6 +193,7 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe`
|
|||||||
Sleep 3000
|
Sleep 3000
|
||||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||||
|
Delete "$INSTDIR\wintun.dll"
|
||||||
RmDir /r "$INSTDIR"
|
RmDir /r "$INSTDIR"
|
||||||
|
|
||||||
SetShellVarContext current
|
SetShellVarContext current
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -28,7 +27,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||||
"Tailscale", "tailscale", "docker", "veth", "br-"}
|
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
|
||||||
|
|
||||||
// ConfigInput carries configuration changes to the client
|
// ConfigInput carries configuration changes to the client
|
||||||
type ConfigInput struct {
|
type ConfigInput struct {
|
||||||
@@ -74,6 +73,62 @@ type Config struct {
|
|||||||
CustomDNSAddress string
|
CustomDNSAddress string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
|
func ReadConfig(configPath string) (*Config, error) {
|
||||||
|
if configFileIsExists(configPath) {
|
||||||
|
config := &Config{}
|
||||||
|
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = WriteOutConfig(configPath, cfg)
|
||||||
|
return cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||||
|
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||||
|
if !configFileIsExists(input.ConfigPath) {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
return update(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||||
|
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||||
|
if !configFileIsExists(input.ConfigPath) {
|
||||||
|
log.Infof("generating new config %s", input.ConfigPath)
|
||||||
|
cfg, err := createNewConfig(input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = WriteOutConfig(input.ConfigPath, cfg)
|
||||||
|
return cfg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||||
|
input.PreSharedKey = nil
|
||||||
|
}
|
||||||
|
return update(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateInMemoryConfig generate a new config but do not write out it to the store
|
||||||
|
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||||
|
return createNewConfig(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteOutConfig write put the prepared config to the given path
|
||||||
|
func WriteOutConfig(path string, config *Config) error {
|
||||||
|
return util.WriteJson(path, config)
|
||||||
|
}
|
||||||
|
|
||||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||||
func createNewConfig(input ConfigInput) (*Config, error) {
|
func createNewConfig(input ConfigInput) (*Config, error) {
|
||||||
wgKey := generateKey()
|
wgKey := generateKey()
|
||||||
@@ -92,14 +147,14 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
CustomDNSAddress: string(input.CustomDNSAddress),
|
CustomDNSAddress: string(input.CustomDNSAddress),
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultManagementURL, err := ParseURL("Management URL", DefaultManagementURL)
|
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.ManagementURL = defaultManagementURL
|
config.ManagementURL = defaultManagementURL
|
||||||
if input.ManagementURL != "" {
|
if input.ManagementURL != "" {
|
||||||
URL, err := ParseURL("Management URL", input.ManagementURL)
|
URL, err := parseURL("Management URL", input.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -110,14 +165,14 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config.PreSharedKey = *input.PreSharedKey
|
config.PreSharedKey = *input.PreSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultAdminURL, err := ParseURL("Admin URL", DefaultAdminURL)
|
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
config.AdminURL = defaultAdminURL
|
config.AdminURL = defaultAdminURL
|
||||||
if input.AdminURL != "" {
|
if input.AdminURL != "" {
|
||||||
newURL, err := ParseURL("Admin Panel URL", input.AdminURL)
|
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -125,49 +180,11 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
config.IFaceBlackList = defaultInterfaceBlacklist
|
config.IFaceBlackList = defaultInterfaceBlacklist
|
||||||
|
|
||||||
err = util.WriteJson(input.ConfigPath, config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseURL parses and validates a service URL
|
func update(input ConfigInput) (*Config, error) {
|
||||||
func ParseURL(serviceName, serviceURL string) (*url.URL, error) {
|
|
||||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing %s URL %s: [%s]", serviceName, serviceURL, err.Error())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if parsedMgmtURL.Scheme != "https" && parsedMgmtURL.Scheme != "http" {
|
|
||||||
return nil, fmt.Errorf(
|
|
||||||
"invalid %s URL provided %s. Supported format [http|https]://[host]:[port]",
|
|
||||||
serviceName, serviceURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if parsedMgmtURL.Port() == "" {
|
|
||||||
switch parsedMgmtURL.Scheme {
|
|
||||||
case "https":
|
|
||||||
parsedMgmtURL.Host = parsedMgmtURL.Host + ":443"
|
|
||||||
case "http":
|
|
||||||
parsedMgmtURL.Host = parsedMgmtURL.Host + ":80"
|
|
||||||
default:
|
|
||||||
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return parsedMgmtURL, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadConfig reads existing configuration and update settings according to input configuration
|
|
||||||
func ReadConfig(input ConfigInput) (*Config, error) {
|
|
||||||
config := &Config{}
|
config := &Config{}
|
||||||
if _, err := os.Stat(input.ConfigPath); os.IsNotExist(err) {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -178,7 +195,7 @@ func ReadConfig(input ConfigInput) (*Config, error) {
|
|||||||
if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
|
if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
|
||||||
log.Infof("new Management URL provided, updated to %s (old value %s)",
|
log.Infof("new Management URL provided, updated to %s (old value %s)",
|
||||||
input.ManagementURL, config.ManagementURL)
|
input.ManagementURL, config.ManagementURL)
|
||||||
newURL, err := ParseURL("Management URL", input.ManagementURL)
|
newURL, err := parseURL("Management URL", input.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -189,7 +206,7 @@ func ReadConfig(input ConfigInput) (*Config, error) {
|
|||||||
if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
|
if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
|
||||||
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
|
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
|
||||||
input.AdminURL, config.AdminURL)
|
input.AdminURL, config.AdminURL)
|
||||||
newURL, err := ParseURL("Admin Panel URL", input.AdminURL)
|
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -237,18 +254,32 @@ func ReadConfig(input ConfigInput) (*Config, error) {
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConfig reads existing config or generates a new one
|
// parseURL parses and validates a service URL
|
||||||
func GetConfig(input ConfigInput) (*Config, error) {
|
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||||
if _, err := os.Stat(input.ConfigPath); os.IsNotExist(err) {
|
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||||
log.Infof("generating new config %s", input.ConfigPath)
|
if err != nil {
|
||||||
return createNewConfig(input)
|
log.Errorf("failed parsing %s URL %s: [%s]", serviceName, serviceURL, err.Error())
|
||||||
} else {
|
return nil, err
|
||||||
// don't overwrite pre-shared key if we receive asterisks from UI
|
|
||||||
if *input.PreSharedKey == "**********" {
|
|
||||||
input.PreSharedKey = nil
|
|
||||||
}
|
|
||||||
return ReadConfig(input)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if parsedMgmtURL.Scheme != "https" && parsedMgmtURL.Scheme != "http" {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"invalid %s URL provided %s. Supported format [http|https]://[host]:[port]",
|
||||||
|
serviceName, serviceURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedMgmtURL.Port() == "" {
|
||||||
|
switch parsedMgmtURL.Scheme {
|
||||||
|
case "https":
|
||||||
|
parsedMgmtURL.Host = parsedMgmtURL.Host + ":443"
|
||||||
|
case "http":
|
||||||
|
parsedMgmtURL.Host = parsedMgmtURL.Host + ":80"
|
||||||
|
default:
|
||||||
|
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedMgmtURL, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateKey generates a new Wireguard private key
|
// generateKey generates a new Wireguard private key
|
||||||
@@ -260,107 +291,15 @@ func generateKey() string {
|
|||||||
return key.String()
|
return key.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
// don't overwrite pre-shared key if we receive asterisks from UI
|
||||||
type DeviceAuthorizationFlow struct {
|
func isPreSharedKeyHidden(preSharedKey *string) bool {
|
||||||
Provider string
|
if preSharedKey != nil && *preSharedKey == "**********" {
|
||||||
ProviderConfig ProviderConfig
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
func configFileIsExists(path string) bool {
|
||||||
type ProviderConfig struct {
|
_, err := os.Stat(path)
|
||||||
// ClientID An IDP application client id
|
return !os.IsNotExist(err)
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Domain An IDP API domain
|
|
||||||
// Deprecated. Use OIDCConfigEndpoint instead
|
|
||||||
Domain string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
|
||||||
DeviceAuthEndpoint string
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (DeviceAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
|
||||||
if config.ManagementURL.Scheme == "https" {
|
|
||||||
mgmTlsEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", config.ManagementURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", config.ManagementURL.String(), err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
} else {
|
|
||||||
log.Errorf("failed to retrieve device flow: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
|
||||||
|
|
||||||
ProviderConfig: ProviderConfig{
|
|
||||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
|
||||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceAuthorizationFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isProviderConfigValid(config ProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.Audience == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
|
||||||
}
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.DeviceAuthEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
func TestGetConfig(t *testing.T) {
|
func TestGetConfig(t *testing.T) {
|
||||||
// case 1: new default config has to be generated
|
// case 1: new default config has to be generated
|
||||||
config, err := GetConfig(ConfigInput{
|
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ func TestGetConfig(t *testing.T) {
|
|||||||
preSharedKey := "preSharedKey"
|
preSharedKey := "preSharedKey"
|
||||||
|
|
||||||
// case 2: new config has to be generated
|
// case 2: new config has to be generated
|
||||||
config, err = GetConfig(ConfigInput{
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: path,
|
ConfigPath: path,
|
||||||
@@ -50,7 +50,7 @@ func TestGetConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// case 3: existing config -> fetch it
|
// case 3: existing config -> fetch it
|
||||||
config, err = GetConfig(ConfigInput{
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: path,
|
ConfigPath: path,
|
||||||
@@ -65,7 +65,7 @@ func TestGetConfig(t *testing.T) {
|
|||||||
|
|
||||||
// case 4: existing config, but new managementURL has been provided -> update config
|
// case 4: existing config, but new managementURL has been provided -> update config
|
||||||
newManagementURL := "https://test.newManagement.url:33071"
|
newManagementURL := "https://test.newManagement.url:33071"
|
||||||
config, err = GetConfig(ConfigInput{
|
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||||
ManagementURL: newManagementURL,
|
ManagementURL: newManagementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
ConfigPath: path,
|
ConfigPath: path,
|
||||||
@@ -85,3 +85,40 @@ func TestGetConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL)
|
assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHiddenPreSharedKey(t *testing.T) {
|
||||||
|
hidden := "**********"
|
||||||
|
samplePreSharedKey := "mysecretpresharedkey"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
preSharedKey *string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"nil", nil, ""},
|
||||||
|
{"hidden", &hidden, ""},
|
||||||
|
{"filled", &samplePreSharedKey, samplePreSharedKey},
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate default cfg
|
||||||
|
cfgFile := filepath.Join(t.TempDir(), "config.json")
|
||||||
|
_, _ = UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ConfigPath: cfgFile,
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ConfigPath: cfgFile,
|
||||||
|
PreSharedKey: tt.preSharedKey,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get cfg: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.PreSharedKey != tt.want {
|
||||||
|
t.Fatalf("invalid preshared key: '%s', expected: '%s' ", cfg.PreSharedKey, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,25 +6,24 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Status) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -60,9 +59,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
managementURL := config.ManagementURL.String()
|
defer statusRecorder.ClientStop()
|
||||||
statusRecorder.MarkManagementDisconnected(managementURL)
|
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
select {
|
select {
|
||||||
@@ -75,7 +72,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
|
|
||||||
engineCtx, cancel := context.WithCancel(ctx)
|
engineCtx, cancel := context.WithCancel(ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
statusRecorder.MarkManagementDisconnected(managementURL)
|
statusRecorder.MarkManagementDisconnected()
|
||||||
statusRecorder.CleanLocalPeerState()
|
statusRecorder.CleanLocalPeerState()
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -85,6 +82,9 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||||
}
|
}
|
||||||
|
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
|
||||||
|
mgmClient.SetConnStateListener(mgmNotifier)
|
||||||
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
||||||
defer func() {
|
defer func() {
|
||||||
err = mgmClient.Close()
|
err = mgmClient.Close()
|
||||||
@@ -103,12 +103,12 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
}
|
}
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
statusRecorder.MarkManagementConnected(managementURL)
|
statusRecorder.MarkManagementConnected()
|
||||||
|
|
||||||
localPeerState := nbStatus.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,8 +119,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
|
||||||
)
|
)
|
||||||
|
|
||||||
statusRecorder.MarkSignalDisconnected(signalURL)
|
statusRecorder.UpdateSignalAddress(signalURL)
|
||||||
defer statusRecorder.MarkSignalDisconnected(signalURL)
|
|
||||||
|
statusRecorder.MarkSignalDisconnected()
|
||||||
|
defer statusRecorder.MarkSignalDisconnected()
|
||||||
|
|
||||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
||||||
@@ -135,11 +137,14 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
statusRecorder.MarkSignalConnected(signalURL)
|
signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
|
||||||
|
signalClient.SetConnStateListener(signalNotifier)
|
||||||
|
|
||||||
|
statusRecorder.MarkSignalConnected()
|
||||||
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
|
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@@ -155,7 +160,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
|
statusRecorder.ClientStart()
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
backOff.Reset()
|
backOff.Reset()
|
||||||
|
|
||||||
@@ -186,11 +194,13 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*EngineConfig, error) {
|
||||||
|
|
||||||
engineConf := &EngineConfig{
|
engineConf := &EngineConfig{
|
||||||
WgIfaceName: config.WgIface,
|
WgIfaceName: config.WgIface,
|
||||||
WgAddr: peerConfig.Address,
|
WgAddr: peerConfig.Address,
|
||||||
|
TunAdapter: tunAdapter,
|
||||||
|
IFaceDiscover: iFaceDiscover,
|
||||||
IFaceBlackList: config.IFaceBlackList,
|
IFaceBlackList: config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
@@ -251,7 +261,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
|||||||
// The check is performed only for the NetBird's managed version.
|
// The check is performed only for the NetBird's managed version.
|
||||||
func UpdateOldManagementPort(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
func UpdateOldManagementPort(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||||
|
|
||||||
defaultManagementURL, err := ParseURL("Management URL", DefaultManagementURL)
|
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -273,7 +283,7 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
|
|||||||
|
|
||||||
if mgmTlsEnabled && config.ManagementURL.Port() == fmt.Sprintf("%d", ManagementLegacyPort) {
|
if mgmTlsEnabled && config.ManagementURL.Port() == fmt.Sprintf("%d", ManagementLegacyPort) {
|
||||||
|
|
||||||
newURL, err := ParseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||||
config.ManagementURL.Scheme, config.ManagementURL.Hostname(), 443))
|
config.ManagementURL.Scheme, config.ManagementURL.Hostname(), 443))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -307,7 +317,7 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// everything is alright => update the config
|
// everything is alright => update the config
|
||||||
newConfig, err := ReadConfig(ConfigInput{
|
newConfig, err := UpdateConfig(ConfigInput{
|
||||||
ManagementURL: newURL.String(),
|
ManagementURL: newURL.String(),
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
})
|
})
|
||||||
@@ -322,3 +332,15 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
|
|||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||||
|
var sri interface{} = statusRecorder
|
||||||
|
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
|
||||||
|
return mgmNotifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal.ConnStateNotifier {
|
||||||
|
var sri interface{} = statusRecorder
|
||||||
|
notifier, _ := sri.(signal.ConnStateNotifier)
|
||||||
|
return notifier
|
||||||
|
}
|
||||||
|
|||||||
134
client/internal/device_auth.go
Normal file
134
client/internal/device_auth.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||||
|
type DeviceAuthorizationFlow struct {
|
||||||
|
Provider string
|
||||||
|
ProviderConfig ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderConfig has all attributes needed to initiate a device authorization flow
|
||||||
|
type ProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Domain An IDP API domain
|
||||||
|
// Deprecated. Use OIDCConfigEndpoint instead
|
||||||
|
Domain string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||||
|
DeviceAuthEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||||
|
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
||||||
|
// validate our peer's Wireguard PRIVATE key
|
||||||
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||||
|
return DeviceAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var mgmTLSEnabled bool
|
||||||
|
if mgmURL.Scheme == "https" {
|
||||||
|
mgmTLSEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
||||||
|
return DeviceAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err = mgmClient.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to close the Management service client %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
serverKey, err := mgmClient.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return DeviceAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||||
|
return DeviceAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve device flow: %v", err)
|
||||||
|
return DeviceAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
||||||
|
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||||
|
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||||
|
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||||
|
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||||
|
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||||
|
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||||
|
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||||
|
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||||
|
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// keep compatibility with older management versions
|
||||||
|
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||||
|
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||||
|
}
|
||||||
|
|
||||||
|
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||||
|
if err != nil {
|
||||||
|
return DeviceAuthorizationFlow{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return deviceAuthorizationFlow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProviderConfigValid(config ProviderConfig) error {
|
||||||
|
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
if config.Audience == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Audience")
|
||||||
|
}
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.DeviceAuthEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -3,8 +3,9 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -14,6 +15,7 @@ const (
|
|||||||
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||||
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
||||||
fileMaxLineCharsLimit = 256
|
fileMaxLineCharsLimit = 256
|
||||||
@@ -66,7 +68,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
var searchDomains string
|
var searchDomains string
|
||||||
appendedDomains := 0
|
appendedDomains := 0
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
if dConf.matchOnly {
|
if dConf.matchOnly || dConf.disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
@@ -19,6 +20,7 @@ type hostDNSConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type domainConfig struct {
|
type domainConfig struct {
|
||||||
|
disabled bool
|
||||||
domain string
|
domain string
|
||||||
matchOnly bool
|
matchOnly bool
|
||||||
}
|
}
|
||||||
@@ -56,6 +58,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD
|
|||||||
serverPort: port,
|
serverPort: port,
|
||||||
}
|
}
|
||||||
for _, nsConfig := range dnsConfig.NameServerGroups {
|
for _, nsConfig := range dnsConfig.NameServerGroups {
|
||||||
|
if len(nsConfig.NameServers) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if nsConfig.Primary {
|
if nsConfig.Primary {
|
||||||
config.routeAll = true
|
config.routeAll = true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -61,6 +62,9 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
)
|
)
|
||||||
|
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if dConf.matchOnly {
|
if dConf.matchOnly {
|
||||||
matchDomains = append(matchDomains, dConf.domain)
|
matchDomains = append(matchDomains, dConf.domain)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -63,6 +64,9 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
)
|
)
|
||||||
|
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !dConf.matchOnly {
|
if !dConf.matchOnly {
|
||||||
searchDomains = append(searchDomains, dConf.domain)
|
searchDomains = append(searchDomains, dConf.domain)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type registrationMap map[string]struct{}
|
||||||
|
|
||||||
type localResolver struct {
|
type localResolver struct {
|
||||||
registeredMap registrationMap
|
registeredMap registrationMap
|
||||||
records sync.Map
|
records sync.Map
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/miekg/dns"
|
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockResponseWriter struct {
|
type mockResponseWriter struct {
|
||||||
|
|||||||
@@ -4,14 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net/netip"
|
|
||||||
"regexp"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -75,12 +76,12 @@ func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager,
|
|||||||
}
|
}
|
||||||
defer closeConn()
|
defer closeConn()
|
||||||
var s string
|
var s string
|
||||||
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.GetName()).Store(&s)
|
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.Name()).Store(&s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.GetName())
|
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.Name())
|
||||||
|
|
||||||
return &networkManagerDbusConfigurator{
|
return &networkManagerDbusConfigurator{
|
||||||
dbusLinkObject: dbus.ObjectPath(s),
|
dbusLinkObject: dbus.ObjectPath(s),
|
||||||
@@ -106,6 +107,9 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
|
|||||||
matchDomains []string
|
matchDomains []string
|
||||||
)
|
)
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if dConf.matchOnly {
|
if dConf.matchOnly {
|
||||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const resolvconfCommand = "resolvconf"
|
const resolvconfCommand = "resolvconf"
|
||||||
@@ -16,7 +17,7 @@ type resolvconf struct {
|
|||||||
|
|
||||||
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
return &resolvconf{
|
return &resolvconf{
|
||||||
ifaceName: wgInterface.GetName(),
|
ifaceName: wgInterface.Name(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
var searchDomains string
|
var searchDomains string
|
||||||
appendedDomains := 0
|
appendedDomains := 0
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
if dConf.matchOnly {
|
if dConf.matchOnly || dConf.disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,26 +1,6 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultPort = 53
|
|
||||||
customPort = 5053
|
|
||||||
defaultIP = "127.0.0.1"
|
|
||||||
customIP = "127.0.0.153"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
@@ -28,357 +8,3 @@ type Server interface {
|
|||||||
Stop()
|
Stop()
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultServer dns server object
|
|
||||||
type DefaultServer struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
mux sync.Mutex
|
|
||||||
server *dns.Server
|
|
||||||
dnsMux *dns.ServeMux
|
|
||||||
dnsMuxMap registrationMap
|
|
||||||
localResolver *localResolver
|
|
||||||
wgInterface *iface.WGIface
|
|
||||||
hostManager hostManager
|
|
||||||
updateSerial uint64
|
|
||||||
listenerIsRunning bool
|
|
||||||
runtimePort int
|
|
||||||
runtimeIP string
|
|
||||||
previousConfigHash uint64
|
|
||||||
customAddress *netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
type registrationMap map[string]struct{}
|
|
||||||
|
|
||||||
type muxUpdate struct {
|
|
||||||
domain string
|
|
||||||
handler dns.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
|
||||||
Net: "udp",
|
|
||||||
Handler: mux,
|
|
||||||
UDPSize: 65535,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, stop := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
var addrPort *netip.AddrPort
|
|
||||||
if customAddress != "" {
|
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
|
||||||
if err != nil {
|
|
||||||
stop()
|
|
||||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
|
||||||
}
|
|
||||||
addrPort = &parsedAddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: stop,
|
|
||||||
server: dnsServer,
|
|
||||||
dnsMux: mux,
|
|
||||||
dnsMuxMap: make(registrationMap),
|
|
||||||
localResolver: &localResolver{
|
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
wgInterface: wgInterface,
|
|
||||||
runtimePort: defaultPort,
|
|
||||||
customAddress: addrPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
hostmanager, err := newHostManager(wgInterface)
|
|
||||||
if err != nil {
|
|
||||||
stop()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defaultServer.hostManager = hostmanager
|
|
||||||
return defaultServer, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start runs the listener in a go routine
|
|
||||||
func (s *DefaultServer) Start() {
|
|
||||||
|
|
||||||
if s.customAddress != nil {
|
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
|
||||||
} else {
|
|
||||||
ip, port, err := s.getFirstListenerAvailable()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtimeIP = ip
|
|
||||||
s.runtimePort = port
|
|
||||||
}
|
|
||||||
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
|
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
err := s.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
|
||||||
ips := []string{defaultIP, customIP}
|
|
||||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
|
||||||
ips = append([]string{s.wgInterface.GetAddress().IP.String()}, ips...)
|
|
||||||
}
|
|
||||||
ports := []int{defaultPort, customPort}
|
|
||||||
for _, port := range ports {
|
|
||||||
for _, ip := range ips {
|
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err == nil {
|
|
||||||
err = probeListener.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
|
||||||
}
|
|
||||||
return ip, port, nil
|
|
||||||
}
|
|
||||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
|
||||||
s.listenerIsRunning = running
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the server
|
|
||||||
func (s *DefaultServer) Stop() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
s.stop()
|
|
||||||
|
|
||||||
err := s.hostManager.restoreHostDNS()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.stopListener()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) stopListener() error {
|
|
||||||
if !s.listenerIsRunning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := s.server.ShutdownContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
|
||||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
log.Infof("not updating DNS server as context is closed")
|
|
||||||
return s.ctx.Err()
|
|
||||||
default:
|
|
||||||
if serial < s.updateSerial {
|
|
||||||
return fmt.Errorf("not applying dns update, error: "+
|
|
||||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
|
||||||
}
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
|
||||||
ZeroNil: true,
|
|
||||||
IgnoreZeroValue: true,
|
|
||||||
SlicesAsSets: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.previousConfigHash == hash {
|
|
||||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
|
||||||
s.updateSerial = serial
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// is the service should be disabled, we stop the listener
|
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
|
||||||
if !update.ServiceEnable {
|
|
||||||
err := s.stopListener()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
} else if !s.listenerIsRunning {
|
|
||||||
s.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
|
||||||
s.updateLocalResolver(localRecords)
|
|
||||||
|
|
||||||
err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort))
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.updateSerial = serial
|
|
||||||
s.previousConfigHash = hash
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
|
||||||
|
|
||||||
if len(customZone.Records) == 0 {
|
|
||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: customZone.Domain,
|
|
||||||
handler: s.localResolver,
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
|
||||||
var class uint16 = dns.ClassINET
|
|
||||||
if record.Class != nbdns.DefaultClass {
|
|
||||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
|
||||||
}
|
|
||||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
|
||||||
localRecords[key] = record
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, localRecords, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
for _, nsGroup := range nameServerGroups {
|
|
||||||
if len(nsGroup.NameServers) == 0 {
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
|
|
||||||
}
|
|
||||||
handler := &upstreamResolver{
|
|
||||||
parentCTX: s.ctx,
|
|
||||||
upstreamClient: &dns.Client{},
|
|
||||||
upstreamTimeout: defaultUpstreamTimeout,
|
|
||||||
}
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
|
||||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
|
||||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
|
||||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: nbdns.RootZone,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nsGroup.Domains) == 0 {
|
|
||||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
if domain == "" {
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
|
||||||
}
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: domain,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|
||||||
muxUpdateMap := make(registrationMap)
|
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
|
||||||
s.registerMux(update.domain, update.handler)
|
|
||||||
muxUpdateMap[update.domain] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for key := range s.dnsMuxMap {
|
|
||||||
_, found := muxUpdateMap[key]
|
|
||||||
if !found {
|
|
||||||
s.deregisterMux(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
|
||||||
for key := range s.localResolver.registeredMap {
|
|
||||||
_, found := update[key]
|
|
||||||
if !found {
|
|
||||||
s.localResolver.deleteRecord(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap := make(registrationMap)
|
|
||||||
for key, record := range update {
|
|
||||||
err := s.localResolver.registerRecord(record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
|
||||||
}
|
|
||||||
updatedMap[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.localResolver.registeredMap = updatedMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
|
||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
|
||||||
s.dnsMux.HandleRemove(pattern)
|
|
||||||
}
|
|
||||||
|
|||||||
32
client/internal/dns/server_android.go
Normal file
32
client/internal/dns/server_android.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultServer dummy dns server
|
||||||
|
type DefaultServer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServer On Android the DNS feature is not supported yet
|
||||||
|
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
|
return &DefaultServer{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start dummy implementation
|
||||||
|
func (s DefaultServer) Start() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop dummy implementation
|
||||||
|
func (s DefaultServer) Stop() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDNSServer dummy implementation
|
||||||
|
func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
465
client/internal/dns/server_nonandroid.go
Normal file
465
client/internal/dns/server_nonandroid.go
Normal file
@@ -0,0 +1,465 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPort = 53
|
||||||
|
customPort = 5053
|
||||||
|
defaultIP = "127.0.0.1"
|
||||||
|
customIP = "127.0.0.153"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultServer dns server object
|
||||||
|
type DefaultServer struct {
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
upstreamCtxCancel context.CancelFunc
|
||||||
|
mux sync.Mutex
|
||||||
|
server *dns.Server
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
dnsMuxMap registrationMap
|
||||||
|
localResolver *localResolver
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
hostManager hostManager
|
||||||
|
updateSerial uint64
|
||||||
|
listenerIsRunning bool
|
||||||
|
runtimePort int
|
||||||
|
runtimeIP string
|
||||||
|
previousConfigHash uint64
|
||||||
|
currentConfig hostDNSConfig
|
||||||
|
customAddress *netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
type muxUpdate struct {
|
||||||
|
domain string
|
||||||
|
handler dns.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServer returns a new dns server
|
||||||
|
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
dnsServer := &dns.Server{
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
UDPSize: 65535,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
var addrPort *netip.AddrPort
|
||||||
|
if customAddress != "" {
|
||||||
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
|
if err != nil {
|
||||||
|
stop()
|
||||||
|
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||||
|
}
|
||||||
|
addrPort = &parsedAddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultServer := &DefaultServer{
|
||||||
|
ctx: ctx,
|
||||||
|
ctxCancel: stop,
|
||||||
|
server: dnsServer,
|
||||||
|
dnsMux: mux,
|
||||||
|
dnsMuxMap: make(registrationMap),
|
||||||
|
localResolver: &localResolver{
|
||||||
|
registeredMap: make(registrationMap),
|
||||||
|
},
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
runtimePort: defaultPort,
|
||||||
|
customAddress: addrPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
hostmanager, err := newHostManager(wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
stop()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defaultServer.hostManager = hostmanager
|
||||||
|
return defaultServer, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start runs the listener in a go routine
|
||||||
|
func (s *DefaultServer) Start() {
|
||||||
|
if s.customAddress != nil {
|
||||||
|
s.runtimeIP = s.customAddress.Addr().String()
|
||||||
|
s.runtimePort = int(s.customAddress.Port())
|
||||||
|
} else {
|
||||||
|
ip, port, err := s.getFirstListenerAvailable()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtimeIP = ip
|
||||||
|
s.runtimePort = port
|
||||||
|
}
|
||||||
|
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
err := s.server.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||||
|
ips := []string{defaultIP, customIP}
|
||||||
|
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||||
|
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||||
|
}
|
||||||
|
ports := []int{defaultPort, customPort}
|
||||||
|
for _, port := range ports {
|
||||||
|
for _, ip := range ips {
|
||||||
|
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err == nil {
|
||||||
|
err = probeListener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
|
}
|
||||||
|
return ip, port, nil
|
||||||
|
}
|
||||||
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||||
|
s.listenerIsRunning = running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the server
|
||||||
|
func (s *DefaultServer) Stop() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
s.ctxCancel()
|
||||||
|
|
||||||
|
err := s.hostManager.restoreHostDNS()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.stopListener()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) stopListener() error {
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := s.server.ShutdownContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDNSServer processes an update received from the management service
|
||||||
|
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
log.Infof("not updating DNS server as context is closed")
|
||||||
|
return s.ctx.Err()
|
||||||
|
default:
|
||||||
|
if serial < s.updateSerial {
|
||||||
|
return fmt.Errorf("not applying dns update, error: "+
|
||||||
|
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||||
|
}
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||||
|
ZeroNil: true,
|
||||||
|
IgnoreZeroValue: true,
|
||||||
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.previousConfigHash == hash {
|
||||||
|
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||||
|
s.updateSerial = serial
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.applyConfiguration(update); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.updateSerial = serial
|
||||||
|
s.previousConfigHash = hash
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
|
// is the service should be disabled, we stop the listener
|
||||||
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
|
if !update.ServiceEnable {
|
||||||
|
err := s.stopListener()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
} else if !s.listenerIsRunning {
|
||||||
|
s.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||||
|
|
||||||
|
s.updateMux(muxUpdates)
|
||||||
|
s.updateLocalResolver(localRecords)
|
||||||
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||||
|
var muxUpdates []muxUpdate
|
||||||
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
|
for _, customZone := range customZones {
|
||||||
|
|
||||||
|
if len(customZone.Records) == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: customZone.Domain,
|
||||||
|
handler: s.localResolver,
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, record := range customZone.Records {
|
||||||
|
var class uint16 = dns.ClassINET
|
||||||
|
if record.Class != nbdns.DefaultClass {
|
||||||
|
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||||
|
}
|
||||||
|
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||||
|
localRecords[key] = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return muxUpdates, localRecords, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||||
|
// clean up the previous upstream resolver
|
||||||
|
if s.upstreamCtxCancel != nil {
|
||||||
|
s.upstreamCtxCancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
var muxUpdates []muxUpdate
|
||||||
|
for _, nsGroup := range nameServerGroups {
|
||||||
|
if len(nsGroup.NameServers) == 0 {
|
||||||
|
log.Warn("received a nameserver group with empty nameserver list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
||||||
|
|
||||||
|
handler := newUpstreamResolver(ctx)
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
|
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||||
|
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(handler.upstreamServers) == 0 {
|
||||||
|
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// when upstream fails to resolve domain several times over all it servers
|
||||||
|
// it will calls this hook to exclude self from the configuration and
|
||||||
|
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||||
|
// because it is temporal deactivation until next try
|
||||||
|
//
|
||||||
|
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||||
|
// everything we need here is just to re-apply current configuration because it already
|
||||||
|
// contains this upstream settings (temporal deactivation not removed it)
|
||||||
|
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
handler: handler,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nsGroup.Domains) == 0 {
|
||||||
|
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
|
}
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: domain,
|
||||||
|
handler: handler,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return muxUpdates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
|
muxUpdateMap := make(registrationMap)
|
||||||
|
|
||||||
|
for _, update := range muxUpdates {
|
||||||
|
s.registerMux(update.domain, update.handler)
|
||||||
|
muxUpdateMap[update.domain] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range s.dnsMuxMap {
|
||||||
|
_, found := muxUpdateMap[key]
|
||||||
|
if !found {
|
||||||
|
s.deregisterMux(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.dnsMuxMap = muxUpdateMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
|
for key := range s.localResolver.registeredMap {
|
||||||
|
_, found := update[key]
|
||||||
|
if !found {
|
||||||
|
s.localResolver.deleteRecord(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedMap := make(registrationMap)
|
||||||
|
for key, record := range update {
|
||||||
|
err := s.localResolver.registerRecord(record)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||||
|
}
|
||||||
|
updatedMap[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.localResolver.registeredMap = updatedMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNSHostPort(ns nbdns.NameServer) string {
|
||||||
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
|
nsGroup *nbdns.NameServerGroup,
|
||||||
|
handler dns.Handler,
|
||||||
|
) (deactivate func(), reactivate func()) {
|
||||||
|
var removeIndex map[string]int
|
||||||
|
deactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Info("temporary deactivate nameservers group due timeout")
|
||||||
|
|
||||||
|
removeIndex = make(map[string]int)
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
removeIndex[domain] = -1
|
||||||
|
}
|
||||||
|
if nsGroup.Primary {
|
||||||
|
removeIndex[nbdns.RootZone] = -1
|
||||||
|
s.currentConfig.routeAll = false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item := range s.currentConfig.domains {
|
||||||
|
if _, found := removeIndex[item.domain]; found {
|
||||||
|
s.currentConfig.domains[i].disabled = true
|
||||||
|
s.deregisterMux(item.domain)
|
||||||
|
removeIndex[item.domain] = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
for domain, i := range removeIndex {
|
||||||
|
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.currentConfig.domains[i].disabled = false
|
||||||
|
s.registerMux(domain, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Debug("reactivate temporary disabled nameserver group")
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
s.currentConfig.routeAll = true
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -3,13 +3,18 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/miekg/dns"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
@@ -23,7 +28,6 @@ var zoneRecords = []nbdns.SimpleRecord{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
@@ -198,7 +202,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
|
|
||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU)
|
newNet, err := stdnet.NewNet(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -263,7 +271,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSServerStartStop(t *testing.T) {
|
func TestDNSServerStartStop(t *testing.T) {
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
addrPort string
|
addrPort string
|
||||||
@@ -333,6 +340,72 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
|
hostManager := &mockHostConfigurator{}
|
||||||
|
server := DefaultServer{
|
||||||
|
dnsMux: dns.DefaultServeMux,
|
||||||
|
localResolver: &localResolver{
|
||||||
|
registeredMap: make(registrationMap),
|
||||||
|
},
|
||||||
|
hostManager: hostManager,
|
||||||
|
currentConfig: hostDNSConfig{
|
||||||
|
domains: []domainConfig{
|
||||||
|
{false, "domain0", false},
|
||||||
|
{false, "domain1", false},
|
||||||
|
{false, "domain2", false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var domainsUpdate string
|
||||||
|
hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error {
|
||||||
|
domains := []string{}
|
||||||
|
for _, item := range config.domains {
|
||||||
|
if item.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domains = append(domains, item.domain)
|
||||||
|
}
|
||||||
|
domainsUpdate = strings.Join(domains, ",")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||||
|
Domains: []string{"domain1"},
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
deactivate()
|
||||||
|
expected := "domain0,domain2"
|
||||||
|
domains := []string{}
|
||||||
|
for _, item := range server.currentConfig.domains {
|
||||||
|
if item.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domains = append(domains, item.domain)
|
||||||
|
}
|
||||||
|
got := strings.Join(domains, ",")
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
reactivate()
|
||||||
|
expected = "domain0,domain1,domain2"
|
||||||
|
domains = []string{}
|
||||||
|
for _, item := range server.currentConfig.domains {
|
||||||
|
if item.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domains = append(domains, item.domain)
|
||||||
|
}
|
||||||
|
got = strings.Join(domains, ",")
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
@@ -351,11 +424,11 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe
|
|||||||
UDPSize: 65535,
|
UDPSize: 65535,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, stop := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
|
|
||||||
return &DefaultServer{
|
return &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: stop,
|
ctxCancel: cancel,
|
||||||
server: dnsServer,
|
server: dnsServer,
|
||||||
dnsMux: mux,
|
dnsMux: mux,
|
||||||
dnsMuxMap: make(registrationMap),
|
dnsMuxMap: make(registrationMap),
|
||||||
|
|||||||
@@ -3,15 +3,16 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -50,7 +51,7 @@ type systemdDbusLinkDomainsInput struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
iface, err := net.InterfaceByName(wgInterface.GetName())
|
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -95,6 +96,9 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
domainsInput []systemdDbusLinkDomainsInput
|
domainsInput []systemdDbusLinkDomainsInput
|
||||||
)
|
)
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: dns.Fqdn(dConf.domain),
|
Domain: dns.Fqdn(dConf.domain),
|
||||||
MatchOnly: dConf.matchOnly,
|
MatchOnly: dConf.matchOnly,
|
||||||
|
|||||||
@@ -3,44 +3,73 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultUpstreamTimeout = 15 * time.Second
|
const (
|
||||||
|
failsTillDeact = int32(3)
|
||||||
|
reactivatePeriod = time.Minute
|
||||||
|
upstreamTimeout = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
parentCTX context.Context
|
ctx context.Context
|
||||||
upstreamClient *dns.Client
|
upstreamClient *dns.Client
|
||||||
upstreamServers []string
|
upstreamServers []string
|
||||||
upstreamTimeout time.Duration
|
disabled bool
|
||||||
|
failsCount atomic.Int32
|
||||||
|
failsTillDeact int32
|
||||||
|
mutex sync.Mutex
|
||||||
|
reactivatePeriod time.Duration
|
||||||
|
upstreamTimeout time.Duration
|
||||||
|
|
||||||
|
deactivate func()
|
||||||
|
reactivate func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUpstreamResolver(ctx context.Context) *upstreamResolver {
|
||||||
|
return &upstreamResolver{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: &dns.Client{},
|
||||||
|
upstreamTimeout: upstreamTimeout,
|
||||||
|
reactivatePeriod: reactivatePeriod,
|
||||||
|
failsTillDeact: failsTillDeact,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
defer u.checkUpstreamFails()
|
||||||
|
|
||||||
log.Tracef("received an upstream question: %#v", r.Question[0])
|
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-u.parentCTX.Done():
|
case <-u.ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
|
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||||
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == context.DeadlineExceeded || isTimeout(err) {
|
if err == context.DeadlineExceeded || isTimeout(err) {
|
||||||
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
|
log.WithError(err).WithField("upstream", upstream).
|
||||||
|
Warn("got an error while connecting to upstream")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
|
u.failsCount.Add(1)
|
||||||
|
log.WithError(err).WithField("upstream", upstream).
|
||||||
|
Error("got an error while querying the upstream")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,11 +77,58 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
err = w.WriteMsg(rm)
|
err = w.WriteMsg(rm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
|
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
||||||
}
|
}
|
||||||
|
// count the fails only if they happen sequentially
|
||||||
|
u.failsCount.Store(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Errorf("all queries to the upstream nameservers failed with timeout")
|
u.failsCount.Add(1)
|
||||||
|
log.Error("all queries to the upstream nameservers failed with timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
||||||
|
//
|
||||||
|
// If fails count is greater that failsTillDeact, upstream resolving
|
||||||
|
// will be disabled for reactivatePeriod, after that time period fails counter
|
||||||
|
// will be reset and upstream will be reactivated.
|
||||||
|
func (u *upstreamResolver) checkUpstreamFails() {
|
||||||
|
u.mutex.Lock()
|
||||||
|
defer u.mutex.Unlock()
|
||||||
|
|
||||||
|
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-u.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
||||||
|
u.deactivate()
|
||||||
|
u.disabled = true
|
||||||
|
go u.waitUntilReactivation()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitUntilReactivation reset fails counter and activates upstream resolving
|
||||||
|
func (u *upstreamResolver) waitUntilReactivation() {
|
||||||
|
timer := time.NewTimer(u.reactivatePeriod)
|
||||||
|
defer func() {
|
||||||
|
if !timer.Stop() {
|
||||||
|
<-timer.C
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-u.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-timer.C:
|
||||||
|
log.Info("upstream resolving is reactivated")
|
||||||
|
u.failsCount.Store(0)
|
||||||
|
u.reactivate()
|
||||||
|
u.disabled = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTimeout returns true if the given error is a network timeout error.
|
// isTimeout returns true if the given error is a network timeout error.
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
name: "Should Resolve A Record",
|
name: "Should Resolve A Record",
|
||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||||
timeout: defaultUpstreamTimeout,
|
timeout: upstreamTimeout,
|
||||||
expectedAnswer: "1.1.1.1",
|
expectedAnswer: "1.1.1.1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -45,7 +45,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||||
cancelCTX: true,
|
cancelCTX: true,
|
||||||
timeout: defaultUpstreamTimeout,
|
timeout: upstreamTimeout,
|
||||||
responseShouldBeNil: true,
|
responseShouldBeNil: true,
|
||||||
},
|
},
|
||||||
//{
|
//{
|
||||||
@@ -65,12 +65,9 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver := &upstreamResolver{
|
resolver := newUpstreamResolver(ctx)
|
||||||
parentCTX: ctx,
|
resolver.upstreamServers = testCase.InputServers
|
||||||
upstreamClient: &dns.Client{},
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
upstreamServers: testCase.InputServers,
|
|
||||||
upstreamTimeout: testCase.timeout,
|
|
||||||
}
|
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
cancel()
|
cancel()
|
||||||
} else {
|
} else {
|
||||||
@@ -108,3 +105,52 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||||
|
resolver := newUpstreamResolver(context.TODO())
|
||||||
|
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
||||||
|
resolver.failsTillDeact = 0
|
||||||
|
resolver.reactivatePeriod = time.Microsecond * 100
|
||||||
|
|
||||||
|
responseWriter := &mockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
failed := false
|
||||||
|
resolver.deactivate = func() {
|
||||||
|
failed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
reactivated := false
|
||||||
|
resolver.reactivate = func() {
|
||||||
|
reactivated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
|
||||||
|
|
||||||
|
if !failed {
|
||||||
|
t.Errorf("expected that resolving was deactivated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resolver.disabled {
|
||||||
|
t.Errorf("resolver should be disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 200)
|
||||||
|
|
||||||
|
if !reactivated {
|
||||||
|
t.Errorf("expected that resolving was reactivated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resolver.failsCount.Load() != 0 {
|
||||||
|
t.Errorf("fails count after reactivation should be 0")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resolver.disabled {
|
||||||
|
t.Errorf("should be enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,24 +12,23 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||||
@@ -47,6 +46,10 @@ var ErrResetConnection = fmt.Errorf("reset connection")
|
|||||||
type EngineConfig struct {
|
type EngineConfig struct {
|
||||||
WgPort int
|
WgPort int
|
||||||
WgIfaceName string
|
WgIfaceName string
|
||||||
|
// TunAdapter is option. It is necessary for mobile version.
|
||||||
|
TunAdapter iface.TunAdapter
|
||||||
|
|
||||||
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
|
|
||||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||||
WgAddr string
|
WgAddr string
|
||||||
@@ -109,7 +112,7 @@ type Engine struct {
|
|||||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||||
sshServer nbssh.Server
|
sshServer nbssh.Server
|
||||||
|
|
||||||
statusRecorder *nbstatus.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
|
|
||||||
@@ -126,14 +129,14 @@ type Peer struct {
|
|||||||
func NewEngine(
|
func NewEngine(
|
||||||
ctx context.Context, cancel context.CancelFunc,
|
ctx context.Context, cancel context.CancelFunc,
|
||||||
signalClient signal.Client, mgmClient mgm.Client,
|
signalClient signal.Client, mgmClient mgm.Client,
|
||||||
config *EngineConfig, statusRecorder *nbstatus.Status,
|
config *EngineConfig, statusRecorder *peer.Status,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
return &Engine{
|
return &Engine{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
signal: signalClient,
|
signal: signalClient,
|
||||||
mgmClient: mgmClient,
|
mgmClient: mgmClient,
|
||||||
peerConns: map[string]*peer.Conn{},
|
peerConns: make(map[string]*peer.Conn),
|
||||||
syncMsgMux: &sync.Mutex{},
|
syncMsgMux: &sync.Mutex{},
|
||||||
config: config,
|
config: config,
|
||||||
STUNs: []*ice.URL{},
|
STUNs: []*ice.URL{},
|
||||||
@@ -157,115 +160,89 @@ func (e *Engine) Stop() error {
|
|||||||
// Removing peers happens in the conn.CLose() asynchronously
|
// Removing peers happens in the conn.CLose() asynchronously
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
e.close()
|
||||||
if e.wgInterface.Interface != nil {
|
|
||||||
err = e.wgInterface.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.udpMux != nil {
|
|
||||||
if err := e.udpMux.Close(); err != nil {
|
|
||||||
log.Debugf("close udp mux: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.udpMuxSrflx != nil {
|
|
||||||
if err := e.udpMuxSrflx.Close(); err != nil {
|
|
||||||
log.Debugf("close server reflexive udp mux: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.udpMuxConn != nil {
|
|
||||||
if err := e.udpMuxConn.Close(); err != nil {
|
|
||||||
log.Debugf("close udp mux connection: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.udpMuxConnSrflx != nil {
|
|
||||||
if err := e.udpMuxConnSrflx.Close(); err != nil {
|
|
||||||
log.Debugf("close server reflexive udp mux connection: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isNil(e.sshServer) {
|
|
||||||
err := e.sshServer.Stop()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed stopping the SSH server: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.routeManager != nil {
|
|
||||||
e.routeManager.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.dnsServer != nil {
|
|
||||||
e.dnsServer.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start creates a new Wireguard tunnel interface and listens to events from Signal and Management services
|
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||||
// Connections to remote peers are not established here.
|
// Connections to remote peers are not established here.
|
||||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||||
func (e *Engine) Start() error {
|
func (e *Engine) Start() error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
wgIfaceName := e.config.WgIfaceName
|
wgIFaceName := e.config.WgIfaceName
|
||||||
wgAddr := e.config.WgAddr
|
wgAddr := e.config.WgAddr
|
||||||
myPrivateKey := e.config.WgPrivateKey
|
myPrivateKey := e.config.WgPrivateKey
|
||||||
var err error
|
var err error
|
||||||
|
transportNet, err := e.newStdNet()
|
||||||
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
|
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||||
|
}
|
||||||
|
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter, transportNet)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
networkName := "udp"
|
|
||||||
if e.config.DisableIPv6Discovery {
|
|
||||||
networkName = "udp4"
|
|
||||||
}
|
|
||||||
|
|
||||||
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: e.udpMuxConn})
|
|
||||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx})
|
|
||||||
|
|
||||||
err = e.wgInterface.Create()
|
err = e.wgInterface.Create()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
|
||||||
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
|
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
|
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
|
||||||
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.wgInterface.IsUserspaceBind() {
|
||||||
|
iceBind := e.wgInterface.GetBind()
|
||||||
|
udpMux, err := iceBind.GetICEMux()
|
||||||
|
if err != nil {
|
||||||
|
e.close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
e.udpMux = udpMux.UDPMuxDefault
|
||||||
|
e.udpMuxSrflx = udpMux
|
||||||
|
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
|
||||||
|
} else {
|
||||||
|
networkName := "udp"
|
||||||
|
if e.config.DisableIPv6Discovery {
|
||||||
|
networkName = "udp4"
|
||||||
|
}
|
||||||
|
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
||||||
|
e.close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
udpMuxParams := ice.UDPMuxParams{
|
||||||
|
UDPConn: e.udpMuxConn,
|
||||||
|
Net: transportNet,
|
||||||
|
}
|
||||||
|
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
|
||||||
|
|
||||||
|
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
|
||||||
|
e.close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
|
||||||
|
}
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
||||||
|
|
||||||
if e.dnsServer == nil {
|
if e.dnsServer == nil {
|
||||||
// todo fix custom address
|
// todo fix custom address
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
@@ -381,42 +358,6 @@ func (e *Engine) removePeer(peerKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerConnectionStatus returns a connection Status or nil if peer connection wasn't found
|
|
||||||
func (e *Engine) GetPeerConnectionStatus(peerKey string) peer.ConnStatus {
|
|
||||||
conn, exists := e.peerConns[peerKey]
|
|
||||||
if exists && conn != nil {
|
|
||||||
return conn.Status()
|
|
||||||
}
|
|
||||||
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) GetPeers() []string {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
peers := []string{}
|
|
||||||
for s := range e.peerConns {
|
|
||||||
peers = append(peers, s)
|
|
||||||
}
|
|
||||||
return peers
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
|
||||||
func (e *Engine) GetConnectedPeers() []string {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
peers := []string{}
|
|
||||||
for s, conn := range e.peerConns {
|
|
||||||
if conn.Status() == peer.StatusConnected {
|
|
||||||
peers = append(peers, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return peers
|
|
||||||
}
|
|
||||||
|
|
||||||
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
|
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
|
||||||
err := s.Send(&sProto.Message{
|
err := s.Send(&sProto.Message{
|
||||||
Key: myKey.PublicKey().String(),
|
Key: myKey.PublicKey().String(),
|
||||||
@@ -433,6 +374,10 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sendSignal(message *sProto.Message, s signal.Client) error {
|
||||||
|
return s.Send(message)
|
||||||
|
}
|
||||||
|
|
||||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||||
func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
|
func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
|
||||||
var t sProto.Body_Type
|
var t sProto.Body_Type
|
||||||
@@ -449,6 +394,10 @@ func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKe
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// indicates message support in gRPC
|
||||||
|
msg.Body.FeaturesSupported = []uint32{signal.DirectCheck}
|
||||||
|
|
||||||
err = s.Send(msg)
|
err = s.Send(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -501,7 +450,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
//nil sshServer means it has not yet been started
|
//nil sshServer means it has not yet been started
|
||||||
var err error
|
var err error
|
||||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
|
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
|
||||||
fmt.Sprintf("%s:%d", e.wgInterface.Address.IP.String(), nbssh.DefaultSSHPort))
|
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -534,8 +483,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
if e.wgInterface.Address.String() != conf.Address {
|
if e.wgInterface.Address().String() != conf.Address {
|
||||||
oldAddr := e.wgInterface.Address.String()
|
oldAddr := e.wgInterface.Address().String()
|
||||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||||
err := e.wgInterface.UpdateAddr(conf.Address)
|
err := e.wgInterface.UpdateAddr(conf.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -552,10 +501,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(nbstatus.LocalPeerState{
|
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
|
||||||
IP: e.config.WgAddr,
|
IP: e.config.WgAddr,
|
||||||
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
KernelInterface: iface.WireGuardModuleIsLoaded(),
|
||||||
FQDN: conf.GetFqdn(),
|
FQDN: conf.GetFqdn(),
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -637,6 +586,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
|
|
||||||
// cleanup request, most likely our peer has been deleted
|
// cleanup request, most likely our peer has been deleted
|
||||||
if networkMap.GetRemotePeersIsEmpty() {
|
if networkMap.GetRemotePeersIsEmpty() {
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
@@ -684,6 +635,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
if protoDNSConfig == nil {
|
if protoDNSConfig == nil {
|
||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
@@ -753,6 +705,21 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
|||||||
return dnsUpdate
|
return dnsUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||||
|
replacement := make([]peer.State, len(offlinePeers))
|
||||||
|
for i, offlinePeer := range offlinePeers {
|
||||||
|
log.Debugf("added offline peer %s", offlinePeer.Fqdn)
|
||||||
|
replacement[i] = peer.State{
|
||||||
|
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
||||||
|
PubKey: offlinePeer.GetWgPubKey(),
|
||||||
|
FQDN: offlinePeer.GetFqdn(),
|
||||||
|
ConnStatus: peer.StatusDisconnected,
|
||||||
|
ConnStatusUpdate: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||||
|
}
|
||||||
|
|
||||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
@@ -841,14 +808,6 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
stunTurn = append(stunTurn, e.STUNs...)
|
stunTurn = append(stunTurn, e.STUNs...)
|
||||||
stunTurn = append(stunTurn, e.TURNs...)
|
stunTurn = append(stunTurn, e.TURNs...)
|
||||||
|
|
||||||
proxyConfig := proxy.Config{
|
|
||||||
RemoteKey: pubKey,
|
|
||||||
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
|
|
||||||
WgInterface: e.wgInterface,
|
|
||||||
AllowedIps: allowedIPs,
|
|
||||||
PreSharedKey: e.config.PreSharedKey,
|
|
||||||
}
|
|
||||||
|
|
||||||
// randomize connection timeout
|
// randomize connection timeout
|
||||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||||
config := peer.ConnConfig{
|
config := peer.ConnConfig{
|
||||||
@@ -860,12 +819,12 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
UDPMux: e.udpMux,
|
UDPMux: e.udpMux,
|
||||||
UDPMuxSrflx: e.udpMuxSrflx,
|
UDPMuxSrflx: e.udpMuxSrflx,
|
||||||
ProxyConfig: proxyConfig,
|
|
||||||
LocalWgPort: e.config.WgPort,
|
LocalWgPort: e.config.WgPort,
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
|
AllowedIPs: allowedIPs,
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(config, e.statusRecorder)
|
peerConn, err := peer.NewConn(config, e.wgInterface, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -890,6 +849,9 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
|||||||
peerConn.SetSignalCandidate(signalCandidate)
|
peerConn.SetSignalCandidate(signalCandidate)
|
||||||
peerConn.SetSignalOffer(signalOffer)
|
peerConn.SetSignalOffer(signalOffer)
|
||||||
peerConn.SetSignalAnswer(signalAnswer)
|
peerConn.SetSignalAnswer(signalAnswer)
|
||||||
|
peerConn.SetSendSignalMessage(func(message *sProto.Message) error {
|
||||||
|
return sendSignal(message, e.signal)
|
||||||
|
})
|
||||||
|
|
||||||
return peerConn, nil
|
return peerConn, nil
|
||||||
}
|
}
|
||||||
@@ -913,6 +875,9 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
|
||||||
|
|
||||||
conn.OnRemoteOffer(peer.OfferAnswer{
|
conn.OnRemoteOffer(peer.OfferAnswer{
|
||||||
IceCredentials: peer.IceCredentials{
|
IceCredentials: peer.IceCredentials{
|
||||||
UFrag: remoteCred.UFrag,
|
UFrag: remoteCred.UFrag,
|
||||||
@@ -926,6 +891,9 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
|
||||||
|
|
||||||
conn.OnRemoteAnswer(peer.OfferAnswer{
|
conn.OnRemoteAnswer(peer.OfferAnswer{
|
||||||
IceCredentials: peer.IceCredentials{
|
IceCredentials: peer.IceCredentials{
|
||||||
UFrag: remoteCred.UFrag,
|
UFrag: remoteCred.UFrag,
|
||||||
@@ -941,6 +909,19 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
conn.OnRemoteCandidate(candidate)
|
conn.OnRemoteCandidate(candidate)
|
||||||
|
case sProto.Body_MODE:
|
||||||
|
protoMode := msg.GetBody().GetMode()
|
||||||
|
if protoMode == nil {
|
||||||
|
return fmt.Errorf("received an empty mode message")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := conn.OnModeMessage(peer.ModeMessage{
|
||||||
|
Direct: protoMode.GetDirect(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed processing a mode message -> %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1011,6 +992,49 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
return mappedIPs
|
return mappedIPs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) close() {
|
||||||
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
|
if e.wgInterface != nil {
|
||||||
|
if err := e.wgInterface.Close(); err != nil {
|
||||||
|
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.udpMux != nil {
|
||||||
|
if err := e.udpMux.Close(); err != nil {
|
||||||
|
log.Debugf("close udp mux: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.udpMuxConn != nil {
|
||||||
|
if err := e.udpMuxConn.Close(); err != nil {
|
||||||
|
log.Debugf("close udp mux connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.udpMuxConnSrflx != nil {
|
||||||
|
if err := e.udpMuxConnSrflx.Close(); err != nil {
|
||||||
|
log.Debugf("close server reflexive udp mux connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isNil(e.sshServer) {
|
||||||
|
err := e.sshServer.Stop()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed stopping the SSH server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.routeManager != nil {
|
||||||
|
e.routeManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.dnsServer != nil {
|
||||||
|
e.dnsServer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
11
client/internal/engine_stdnet.go
Normal file
11
client/internal/engine_stdnet.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNet(e.config.IFaceBlackList)
|
||||||
|
}
|
||||||
7
client/internal/engine_stdnet_android.go
Normal file
7
client/internal/engine_stdnet_android.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNetWithDiscover(e.config.IFaceDiscover, e.config.IFaceBlackList)
|
||||||
|
}
|
||||||
@@ -3,16 +3,8 @@ package internal
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/pion/transport/v2/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -23,18 +15,29 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgmt "github.com/netbirdio/netbird/management/client"
|
mgmt "github.com/netbirdio/netbird/management/client"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
"github.com/netbirdio/netbird/signal/proto"
|
"github.com/netbirdio/netbird/signal/proto"
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -71,7 +74,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder("https://mgm"))
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
@@ -205,12 +208,24 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder("https://mgm"))
|
||||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
|
newNet, err := stdnet.NewNet()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
}
|
}
|
||||||
|
conn, err := net.ListenUDP("udp4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
name string
|
name string
|
||||||
@@ -389,7 +404,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder("https://mgm"))
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
@@ -439,7 +454,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(engine.GetPeers()) == 3 && engine.networkSerial == 10 {
|
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -547,8 +562,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder("https://mgm"))
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
newNet, err := stdnet.NewNet()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
input := struct {
|
input := struct {
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
@@ -712,8 +731,12 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder("https://mgm"))
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
newNet, err := stdnet.NewNet()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
@@ -846,7 +869,7 @@ loop:
|
|||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
totalConnected := 0
|
totalConnected := 0
|
||||||
for _, engine := range engines {
|
for _, engine := range engines {
|
||||||
totalConnected = totalConnected + len(engine.GetConnectedPeers())
|
totalConnected = totalConnected + getConnectedPeers(engine)
|
||||||
}
|
}
|
||||||
if totalConnected == expectedConnected {
|
if totalConnected == expectedConnected {
|
||||||
log.Infof("total connected=%d", totalConnected)
|
log.Infof("total connected=%d", totalConnected)
|
||||||
@@ -857,7 +880,7 @@ loop:
|
|||||||
}
|
}
|
||||||
// cleanup test
|
// cleanup test
|
||||||
for n, peerEngine := range engines {
|
for n, peerEngine := range engines {
|
||||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name, n)
|
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||||
errStop := peerEngine.mgmClient.Close()
|
errStop := peerEngine.mgmClient.Close()
|
||||||
if errStop != nil {
|
if errStop != nil {
|
||||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||||
@@ -905,7 +928,7 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
|||||||
expectedOutput: []string{"1.1.1.1", "8.8.8.8/" + testingIP},
|
expectedOutput: []string{"1.1.1.1", "8.8.8.8/" + testingIP},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Only Interface Name Should Return Nil",
|
name: "Only Interface name Should Return Nil",
|
||||||
inputBlacklistInterface: defaultInterfaceBlacklist,
|
inputBlacklistInterface: defaultInterfaceBlacklist,
|
||||||
inputMapList: []string{testingInterface},
|
inputMapList: []string{testingInterface},
|
||||||
expectedOutput: nil,
|
expectedOutput: nil,
|
||||||
@@ -977,7 +1000,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
WgPort: wgPort,
|
WgPort: wgPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, nbstatus.NewRecorder()), nil
|
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startSignal() (*grpc.Server, string, error) {
|
func startSignal() (*grpc.Server, string, error) {
|
||||||
@@ -1044,3 +1067,23 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
|||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
return s, lis.Addr().String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||||
|
func getConnectedPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
i := 0
|
||||||
|
for _, conn := range e.peerConns {
|
||||||
|
if conn.Status() == peer.StatusConnected {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
return len(e.peerConns)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,37 +2,26 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
// IsLoginRequired check that the server is support SSO or not
|
||||||
// validate our peer's Wireguard PRIVATE key
|
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
|
||||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
|
||||||
if config.ManagementURL.Scheme == "https" {
|
|
||||||
mgmTlsEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to the Management service %s %v", config.ManagementURL.String(), err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = mgmClient.Close()
|
err = mgmClient.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -42,47 +31,84 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(sshKey))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||||
|
if isLoginNeeded(err) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login or register the client
|
||||||
|
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
||||||
|
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
err = mgmClient.Close()
|
||||||
|
if err != nil {
|
||||||
|
cStatus, ok := status.FromError(err)
|
||||||
|
if !ok || ok && cStatus.Code() != codes.Canceled {
|
||||||
|
log.Warnf("failed to close the Management service client, err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = loginPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed logging-in peer on Management Service : %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Infof("peer has successfully logged-in to the Management service %s", config.ManagementURL.String())
|
|
||||||
|
|
||||||
err = mgmClient.Close()
|
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||||
if err != nil {
|
if isRegistrationNeeded(err) {
|
||||||
log.Errorf("failed to close the Management service client: %v", err)
|
log.Debugf("peer registration required")
|
||||||
|
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
|
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||||
func loginPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
// validate our peer's Wireguard PRIVATE key
|
||||||
sysInfo := system.GetInfo(ctx)
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
loginResp, err := client.Login(serverPublicKey, sysInfo, pubSSHKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||||
log.Debugf("peer registration required")
|
return nil, err
|
||||||
return registerPeer(ctx, serverPublicKey, client, setupKey, jwtToken, pubSSHKey)
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return loginResp, nil
|
var mgmTlsEnabled bool
|
||||||
|
if mgmURL.Scheme == "https" {
|
||||||
|
mgmTlsEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mgmClient, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
|
||||||
|
serverKey, err := mgmClient.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sysInfo := system.GetInfo(ctx)
|
||||||
|
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
||||||
|
return serverKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
@@ -105,3 +131,31 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
|||||||
|
|
||||||
return loginResp, nil
|
return loginResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isLoginNeeded(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRegistrationNeeded(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.Code() == codes.PermissionDenied {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,15 +35,6 @@ type DeviceAuthInfo struct {
|
|||||||
Interval int `json:"interval"`
|
Interval int `json:"interval"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenInfo holds information of issued access token
|
|
||||||
type TokenInfo struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
IDToken string `json:"id_token"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// HostedGrantType grant type for device flow on Hosted
|
// HostedGrantType grant type for device flow on Hosted
|
||||||
const (
|
const (
|
||||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
@@ -52,16 +43,7 @@ const (
|
|||||||
|
|
||||||
// Hosted client
|
// Hosted client
|
||||||
type Hosted struct {
|
type Hosted struct {
|
||||||
// Hosted API Audience for validation
|
providerConfig ProviderConfig
|
||||||
Audience string
|
|
||||||
// Hosted Native application client id
|
|
||||||
ClientID string
|
|
||||||
// Hosted Native application request scope
|
|
||||||
Scope string
|
|
||||||
// TokenEndpoint to request access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// DeviceAuthEndpoint to request device authorization code
|
|
||||||
DeviceAuthEndpoint string
|
|
||||||
|
|
||||||
HTTPClient HTTPClient
|
HTTPClient HTTPClient
|
||||||
}
|
}
|
||||||
@@ -70,7 +52,7 @@ type Hosted struct {
|
|||||||
type RequestDeviceCodePayload struct {
|
type RequestDeviceCodePayload struct {
|
||||||
Audience string `json:"audience"`
|
Audience string `json:"audience"`
|
||||||
ClientID string `json:"client_id"`
|
ClientID string `json:"client_id"`
|
||||||
Scope string `json:"scope"`
|
Scope string `json:"scope"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenRequestPayload used for requesting the auth0 token
|
// TokenRequestPayload used for requesting the auth0 token
|
||||||
@@ -93,8 +75,26 @@ type Claims struct {
|
|||||||
Audience interface{} `json:"aud"`
|
Audience interface{} `json:"aud"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TokenInfo holds information of issued access token
|
||||||
|
type TokenInfo struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
UseIDToken bool `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
||||||
|
func (t TokenInfo) GetTokenToUse() string {
|
||||||
|
if t.UseIDToken {
|
||||||
|
return t.IDToken
|
||||||
|
}
|
||||||
|
return t.AccessToken
|
||||||
|
}
|
||||||
|
|
||||||
// NewHostedDeviceFlow returns an Hosted OAuth client
|
// NewHostedDeviceFlow returns an Hosted OAuth client
|
||||||
func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted {
|
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
@@ -104,27 +104,23 @@ func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Hosted{
|
return &Hosted{
|
||||||
Audience: audience,
|
providerConfig: config,
|
||||||
ClientID: clientID,
|
HTTPClient: httpClient,
|
||||||
Scope: "openid",
|
|
||||||
TokenEndpoint: tokenEndpoint,
|
|
||||||
HTTPClient: httpClient,
|
|
||||||
DeviceAuthEndpoint: deviceAuthEndpoint,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientID returns the provider client id
|
// GetClientID returns the provider client id
|
||||||
func (h *Hosted) GetClientID(ctx context.Context) string {
|
func (h *Hosted) GetClientID(ctx context.Context) string {
|
||||||
return h.ClientID
|
return h.providerConfig.ClientID
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestDeviceCode requests a device code login flow information from Hosted
|
// RequestDeviceCode requests a device code login flow information from Hosted
|
||||||
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
form.Add("client_id", h.ClientID)
|
form.Add("client_id", h.providerConfig.ClientID)
|
||||||
form.Add("audience", h.Audience)
|
form.Add("audience", h.providerConfig.Audience)
|
||||||
form.Add("scope", h.Scope)
|
form.Add("scope", h.providerConfig.Scope)
|
||||||
req, err := http.NewRequest("POST", h.DeviceAuthEndpoint,
|
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
|
||||||
strings.NewReader(form.Encode()))
|
strings.NewReader(form.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||||
@@ -157,10 +153,10 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
|||||||
|
|
||||||
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
form.Add("client_id", h.ClientID)
|
form.Add("client_id", h.providerConfig.ClientID)
|
||||||
form.Add("grant_type", HostedGrantType)
|
form.Add("grant_type", HostedGrantType)
|
||||||
form.Add("device_code", info.DeviceCode)
|
form.Add("device_code", info.DeviceCode)
|
||||||
req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode()))
|
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||||
}
|
}
|
||||||
@@ -225,18 +221,20 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
|||||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = isValidAccessToken(tokenResponse.AccessToken, h.Audience)
|
|
||||||
if err != nil {
|
|
||||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenInfo := TokenInfo{
|
tokenInfo := TokenInfo{
|
||||||
AccessToken: tokenResponse.AccessToken,
|
AccessToken: tokenResponse.AccessToken,
|
||||||
TokenType: tokenResponse.TokenType,
|
TokenType: tokenResponse.TokenType,
|
||||||
RefreshToken: tokenResponse.RefreshToken,
|
RefreshToken: tokenResponse.RefreshToken,
|
||||||
IDToken: tokenResponse.IDToken,
|
IDToken: tokenResponse.IDToken,
|
||||||
ExpiresIn: tokenResponse.ExpiresIn,
|
ExpiresIn: tokenResponse.ExpiresIn,
|
||||||
|
UseIDToken: h.providerConfig.UseIDToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
|
||||||
|
if err != nil {
|
||||||
|
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return tokenInfo, err
|
return tokenInfo, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,14 +3,15 @@ package internal
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockHTTPClient struct {
|
type mockHTTPClient struct {
|
||||||
@@ -113,12 +114,15 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hosted := Hosted{
|
hosted := Hosted{
|
||||||
Audience: expectedAudience,
|
providerConfig: ProviderConfig{
|
||||||
ClientID: expectedClientID,
|
Audience: expectedAudience,
|
||||||
Scope: expectedScope,
|
ClientID: expectedClientID,
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
Scope: expectedScope,
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
HTTPClient: &httpClient,
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
|
UseIDToken: false,
|
||||||
|
},
|
||||||
|
HTTPClient: &httpClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
||||||
@@ -275,12 +279,15 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hosted := Hosted{
|
hosted := Hosted{
|
||||||
Audience: testCase.inputAudience,
|
providerConfig: ProviderConfig{
|
||||||
ClientID: clientID,
|
Audience: testCase.inputAudience,
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
ClientID: clientID,
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
HTTPClient: &httpClient,
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
}
|
Scope: "openid",
|
||||||
|
UseIDToken: false,
|
||||||
|
},
|
||||||
|
HTTPClient: &httpClient}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -2,18 +2,21 @@ package peer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnConfig is a peer Connection configuration
|
// ConnConfig is a peer Connection configuration
|
||||||
@@ -34,14 +37,14 @@ type ConnConfig struct {
|
|||||||
|
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
|
||||||
ProxyConfig proxy.Config
|
|
||||||
|
|
||||||
UDPMux ice.UDPMux
|
UDPMux ice.UDPMux
|
||||||
UDPMuxSrflx ice.UniversalUDPMux
|
UDPMuxSrflx ice.UniversalUDPMux
|
||||||
|
|
||||||
LocalWgPort int
|
LocalWgPort int
|
||||||
|
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
|
|
||||||
|
AllowedIPs string
|
||||||
}
|
}
|
||||||
|
|
||||||
// OfferAnswer represents a session establishment offer or answer
|
// OfferAnswer represents a session establishment offer or answer
|
||||||
@@ -69,8 +72,9 @@ type Conn struct {
|
|||||||
// signalCandidate is a handler function to signal remote peer about local connection candidate
|
// signalCandidate is a handler function to signal remote peer about local connection candidate
|
||||||
signalCandidate func(candidate ice.Candidate) error
|
signalCandidate func(candidate ice.Candidate) error
|
||||||
// signalOffer is a handler function to signal remote peer our connection offer (credentials)
|
// signalOffer is a handler function to signal remote peer our connection offer (credentials)
|
||||||
signalOffer func(OfferAnswer) error
|
signalOffer func(OfferAnswer) error
|
||||||
signalAnswer func(OfferAnswer) error
|
signalAnswer func(OfferAnswer) error
|
||||||
|
sendSignalMessage func(message *sProto.Message) error
|
||||||
|
|
||||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||||
remoteOffersCh chan OfferAnswer
|
remoteOffersCh chan OfferAnswer
|
||||||
@@ -83,9 +87,26 @@ type Conn struct {
|
|||||||
agent *ice.Agent
|
agent *ice.Agent
|
||||||
status ConnStatus
|
status ConnStatus
|
||||||
|
|
||||||
statusRecorder *nbStatus.Status
|
statusRecorder *Status
|
||||||
|
|
||||||
proxy proxy.Proxy
|
proxy proxy.Proxy
|
||||||
|
remoteModeCh chan ModeMessage
|
||||||
|
meta meta
|
||||||
|
|
||||||
|
wgIface *iface.WGIface
|
||||||
|
adapter iface.TunAdapter
|
||||||
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
|
}
|
||||||
|
|
||||||
|
// meta holds meta information about a connection
|
||||||
|
type meta struct {
|
||||||
|
protoSupport signal.FeaturesSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModeMessage represents a connection mode chosen by the peer
|
||||||
|
type ModeMessage struct {
|
||||||
|
// Direct indicates that it decided to use a direct connection
|
||||||
|
Direct bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConf returns the connection config
|
// GetConf returns the connection config
|
||||||
@@ -100,7 +121,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
|
|||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
// To establish a connection run Conn.Open
|
// To establish a connection run Conn.Open
|
||||||
func NewConn(config ConnConfig, statusRecorder *nbStatus.Status) (*Conn, error) {
|
func NewConn(config ConnConfig, wgIface *iface.WGIface, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
config: config,
|
config: config,
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
@@ -109,53 +130,35 @@ func NewConn(config ConnConfig, statusRecorder *nbStatus.Status) (*Conn, error)
|
|||||||
remoteOffersCh: make(chan OfferAnswer),
|
remoteOffersCh: make(chan OfferAnswer),
|
||||||
remoteAnswerCh: make(chan OfferAnswer),
|
remoteAnswerCh: make(chan OfferAnswer),
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
remoteModeCh: make(chan ModeMessage, 1),
|
||||||
|
adapter: adapter,
|
||||||
|
iFaceDiscover: iFaceDiscover,
|
||||||
|
wgIface: wgIface,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// interfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
|
|
||||||
// to avoid building tunnel over them
|
|
||||||
func interfaceFilter(blackList []string) func(string) bool {
|
|
||||||
|
|
||||||
return func(iFace string) bool {
|
|
||||||
for _, s := range blackList {
|
|
||||||
if strings.HasPrefix(iFace, s) {
|
|
||||||
log.Debugf("ignoring interface %s - it is not allowed", iFace)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// look for unlisted WireGuard interfaces
|
|
||||||
wg, err := wgctrl.New()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("trying to create a wgctrl client failed with: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err := wg.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err = wg.Device(iFace)
|
|
||||||
return err != nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) reCreateAgent() error {
|
func (conn *Conn) reCreateAgent() error {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
failedTimeout := 6 * time.Second
|
failedTimeout := 6 * time.Second
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
transportNet, err := conn.newStdNet()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||||
|
}
|
||||||
agentConfig := &ice.AgentConfig{
|
agentConfig := &ice.AgentConfig{
|
||||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||||
Urls: conn.config.StunTurn,
|
Urls: conn.config.StunTurn,
|
||||||
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
|
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
|
||||||
FailedTimeout: &failedTimeout,
|
FailedTimeout: &failedTimeout,
|
||||||
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList),
|
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
||||||
UDPMux: conn.config.UDPMux,
|
UDPMux: conn.config.UDPMux,
|
||||||
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
UDPMuxSrflx: conn.config.UDPMuxSrflx,
|
||||||
NAT1To1IPs: conn.config.NATExternalIPs,
|
NAT1To1IPs: conn.config.NATExternalIPs,
|
||||||
|
Net: transportNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.config.DisableIPv6Discovery {
|
if conn.config.DisableIPv6Discovery {
|
||||||
@@ -192,11 +195,11 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
func (conn *Conn) Open() error {
|
func (conn *Conn) Open() error {
|
||||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||||
|
|
||||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
peerState := State{PubKey: conn.config.Key}
|
||||||
|
|
||||||
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
peerState.IP = strings.Split(conn.config.AllowedIPs, "/")[0]
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
peerState.ConnStatusUpdate = time.Now()
|
||||||
peerState.ConnStatus = conn.status.String()
|
peerState.ConnStatus = conn.status
|
||||||
|
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -252,9 +255,9 @@ func (conn *Conn) Open() error {
|
|||||||
defer conn.notifyDisconnected()
|
defer conn.notifyDisconnected()
|
||||||
conn.mu.Unlock()
|
conn.mu.Unlock()
|
||||||
|
|
||||||
peerState = nbStatus.PeerState{PubKey: conn.config.Key}
|
peerState = State{PubKey: conn.config.Key}
|
||||||
|
|
||||||
peerState.ConnStatus = conn.status.String()
|
peerState.ConnStatus = conn.status
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
peerState.ConnStatusUpdate = time.Now()
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -291,7 +294,7 @@ func (conn *Conn) Open() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.proxy.Type() == proxy.TypeNoProxy {
|
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
|
||||||
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
|
||||||
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
|
||||||
// direct Wireguard connection
|
// direct Wireguard connection
|
||||||
@@ -312,38 +315,75 @@ func (conn *Conn) Open() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// useProxy determines whether a direct connection (without a go proxy) is possible
|
// useProxy determines whether a direct connection (without a go proxy) is possible
|
||||||
// There are 3 cases: one of the peers has a public IP or both peers are in the same private network
|
//
|
||||||
|
// There are 3 cases:
|
||||||
|
//
|
||||||
|
// * When neither candidate is from hard nat and one of the peers has a public IP
|
||||||
|
//
|
||||||
|
// * both peers are in the same private network
|
||||||
|
//
|
||||||
|
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
|
||||||
|
//
|
||||||
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
|
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
|
||||||
func shouldUseProxy(pair *ice.CandidatePair) bool {
|
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
|
||||||
remoteIP := net.ParseIP(pair.Remote.Address())
|
|
||||||
myIp := net.ParseIP(pair.Local.Address())
|
|
||||||
remoteIsPublic := IsPublicIP(remoteIP)
|
|
||||||
myIsPublic := IsPublicIP(myIp)
|
|
||||||
|
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if !isRelayCandidate(pair.Local) && userspaceBind {
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
//one of the hosts has a public IP
|
|
||||||
if remoteIsPublic && pair.Remote.Type() == ice.CandidateTypeHost {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if myIsPublic && pair.Local.Type() == ice.CandidateTypeHost {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if pair.Local.Type() == ice.CandidateTypeHost && pair.Remote.Type() == ice.CandidateTypeHost {
|
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
|
||||||
if !remoteIsPublic && !myIsPublic {
|
return false
|
||||||
//both hosts are in the same private network
|
}
|
||||||
return false
|
|
||||||
}
|
if !isHardNATCandidate(pair.Remote) && isHostCandidateWithPublicIP(pair.Local) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsPublicIP indicates whether IP is public or not.
|
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
|
||||||
func IsPublicIP(ip net.IP) bool {
|
|
||||||
|
localIPStr, _, err := net.SplitHostPort(pair.Local.Address())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
remoteIPStr, _, err := net.SplitHostPort(pair.Remote.Address())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
localIP := net.ParseIP(localIPStr)
|
||||||
|
remoteIP := net.ParseIP(remoteIPStr)
|
||||||
|
if localIP == nil || remoteIP == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// only consider /16 networks
|
||||||
|
mask := net.IPMask{255, 255, 0, 0}
|
||||||
|
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||||
|
return candidate.Type() == ice.CandidateTypeRelay
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHardNATCandidate(candidate ice.Candidate) bool {
|
||||||
|
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHostCandidateWithPublicIP(candidate ice.Candidate) bool {
|
||||||
|
return candidate.Type() == ice.CandidateTypeHost && isPublicIP(candidate.Address())
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHostCandidateWithPrivateIP(candidate ice.Candidate) bool {
|
||||||
|
return candidate.Type() == ice.CandidateTypeHost && !isPublicIP(candidate.Address())
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPublicIP(address string) bool {
|
||||||
|
ip := net.ParseIP(address)
|
||||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
|
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -361,16 +401,8 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
peerState := State{PubKey: conn.config.Key}
|
||||||
useProxy := shouldUseProxy(pair)
|
p := conn.getProxyWithMessageExchange(pair, remoteWgPort)
|
||||||
var p proxy.Proxy
|
|
||||||
if useProxy {
|
|
||||||
p = proxy.NewWireguardProxy(conn.config.ProxyConfig)
|
|
||||||
peerState.Direct = false
|
|
||||||
} else {
|
|
||||||
p = proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
|
|
||||||
peerState.Direct = true
|
|
||||||
}
|
|
||||||
conn.proxy = p
|
conn.proxy = p
|
||||||
err = p.Start(remoteConn)
|
err = p.Start(remoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -379,13 +411,14 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
|||||||
|
|
||||||
conn.status = StatusConnected
|
conn.status = StatusConnected
|
||||||
|
|
||||||
peerState.ConnStatus = conn.status.String()
|
peerState.ConnStatus = conn.status
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
peerState.ConnStatusUpdate = time.Now()
|
||||||
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
||||||
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
||||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||||
peerState.Relayed = true
|
peerState.Relayed = true
|
||||||
}
|
}
|
||||||
|
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy
|
||||||
|
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -395,6 +428,66 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
|
||||||
|
useProxy := shouldUseProxy(pair, conn.wgIface.IsUserspaceBind())
|
||||||
|
localDirectMode := !useProxy
|
||||||
|
remoteDirectMode := localDirectMode
|
||||||
|
|
||||||
|
if conn.meta.protoSupport.DirectCheck {
|
||||||
|
go conn.sendLocalDirectMode(localDirectMode)
|
||||||
|
// will block until message received or timeout
|
||||||
|
remoteDirectMode = conn.receiveRemoteDirectMode()
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.wgIface.IsUserspaceBind() && localDirectMode {
|
||||||
|
return proxy.NewNoProxy(conn.config.ProxyConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if localDirectMode && remoteDirectMode {
|
||||||
|
//wgInterface *iface.WGIface, remoteKey string, allowedIps string, preSharedKey *wgtypes.Key, remoteWgPort int)
|
||||||
|
return proxy.NewDirectNoProxy(conn.wgIface, conn.config.Key, conn.config.AllowedIPs, remoteWgPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
|
||||||
|
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) sendLocalDirectMode(localMode bool) {
|
||||||
|
// todo what happens when we couldn't deliver this message?
|
||||||
|
// we could retry, etc but there is no guarantee
|
||||||
|
|
||||||
|
err := conn.sendSignalMessage(&sProto.Message{
|
||||||
|
Key: conn.config.LocalKey,
|
||||||
|
RemoteKey: conn.config.Key,
|
||||||
|
Body: &sProto.Body{
|
||||||
|
Type: sProto.Body_MODE,
|
||||||
|
Mode: &sProto.Mode{
|
||||||
|
Direct: &localMode,
|
||||||
|
},
|
||||||
|
NetBirdVersion: version.NetbirdVersion(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to send local proxy mode to remote peer %s, error: %s", conn.config.Key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) receiveRemoteDirectMode() bool {
|
||||||
|
timeout := time.Second
|
||||||
|
timer := time.NewTimer(timeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case receivedMSG := <-conn.remoteModeCh:
|
||||||
|
return receivedMSG.Direct
|
||||||
|
case <-timer.C:
|
||||||
|
// we didn't receive a message from remote so we assume that it supports the direct mode to keep the old behaviour
|
||||||
|
log.Debugf("timeout after %s while waiting for remote direct mode message from remote peer %s",
|
||||||
|
timeout, conn.config.Key)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// cleanup closes all open resources and sets status to StatusDisconnected
|
// cleanup closes all open resources and sets status to StatusDisconnected
|
||||||
func (conn *Conn) cleanup() error {
|
func (conn *Conn) cleanup() error {
|
||||||
log.Debugf("trying to cleanup %s", conn.config.Key)
|
log.Debugf("trying to cleanup %s", conn.config.Key)
|
||||||
@@ -424,8 +517,8 @@ func (conn *Conn) cleanup() error {
|
|||||||
|
|
||||||
conn.status = StatusDisconnected
|
conn.status = StatusDisconnected
|
||||||
|
|
||||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
peerState := State{PubKey: conn.config.Key}
|
||||||
peerState.ConnStatus = conn.status.String()
|
peerState.ConnStatus = conn.status
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
peerState.ConnStatusUpdate = time.Now()
|
||||||
|
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
@@ -455,6 +548,11 @@ func (conn *Conn) SetSignalCandidate(handler func(candidate ice.Candidate) error
|
|||||||
conn.signalCandidate = handler
|
conn.signalCandidate = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSendSignalMessage sets a handler function to be triggered by Conn when there is new message to send via signal
|
||||||
|
func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) error) {
|
||||||
|
conn.sendSignalMessage = handler
|
||||||
|
}
|
||||||
|
|
||||||
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
|
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
|
||||||
// and then signals them to the remote peer
|
// and then signals them to the remote peer
|
||||||
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
|
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
|
||||||
@@ -496,7 +594,7 @@ func (conn *Conn) sendAnswer() error {
|
|||||||
err = conn.signalAnswer(OfferAnswer{
|
err = conn.signalAnswer(OfferAnswer{
|
||||||
IceCredentials: IceCredentials{localUFrag, localPwd},
|
IceCredentials: IceCredentials{localUFrag, localPwd},
|
||||||
WgListenPort: conn.config.LocalWgPort,
|
WgListenPort: conn.config.LocalWgPort,
|
||||||
Version: system.NetbirdVersion(),
|
Version: version.NetbirdVersion(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -517,7 +615,7 @@ func (conn *Conn) sendOffer() error {
|
|||||||
err = conn.signalOffer(OfferAnswer{
|
err = conn.signalOffer(OfferAnswer{
|
||||||
IceCredentials: IceCredentials{localUFrag, localPwd},
|
IceCredentials: IceCredentials{localUFrag, localPwd},
|
||||||
WgListenPort: conn.config.LocalWgPort,
|
WgListenPort: conn.config.LocalWgPort,
|
||||||
Version: system.NetbirdVersion(),
|
Version: version.NetbirdVersion(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -609,3 +707,19 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
|
|||||||
func (conn *Conn) GetKey() string {
|
func (conn *Conn) GetKey() string {
|
||||||
return conn.config.Key
|
return conn.config.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnModeMessage unmarshall the payload message and send it to the mode message channel
|
||||||
|
func (conn *Conn) OnModeMessage(message ModeMessage) error {
|
||||||
|
select {
|
||||||
|
case conn.remoteModeCh <- message:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unable to process mode message: channel busy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProtoSupportMeta register supported proto message in the connection metadata
|
||||||
|
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
|
||||||
|
protoSupport := signal.ParseFeaturesSupported(support)
|
||||||
|
conn.meta.protoSupport = protoSupport
|
||||||
|
}
|
||||||
|
|||||||
29
client/internal/peer/conn_status.go
Normal file
29
client/internal/peer/conn_status.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// StatusConnected indicate the peer is in connected state
|
||||||
|
StatusConnected ConnStatus = iota
|
||||||
|
// StatusConnecting indicate the peer is in connecting state
|
||||||
|
StatusConnecting
|
||||||
|
// StatusDisconnected indicate the peer is in disconnected state
|
||||||
|
StatusDisconnected
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnStatus describe the status of a peer's connection
|
||||||
|
type ConnStatus int
|
||||||
|
|
||||||
|
func (s ConnStatus) String() string {
|
||||||
|
switch s {
|
||||||
|
case StatusConnecting:
|
||||||
|
return "Connecting"
|
||||||
|
case StatusConnected:
|
||||||
|
return "Connected"
|
||||||
|
case StatusDisconnected:
|
||||||
|
return "Disconnected"
|
||||||
|
default:
|
||||||
|
log.Errorf("unknown status: %d", s)
|
||||||
|
return "INVALID_PEER_CONNECTION_STATUS"
|
||||||
|
}
|
||||||
|
}
|
||||||
27
client/internal/peer/conn_status_test.go
Normal file
27
client/internal/peer/conn_status_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/magiconair/properties/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnStatus_String(t *testing.T) {
|
||||||
|
|
||||||
|
tables := []struct {
|
||||||
|
name string
|
||||||
|
status ConnStatus
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"StatusConnected", StatusConnected, "Connected"},
|
||||||
|
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
||||||
|
{"StatusConnecting", StatusConnecting, "Connecting"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range tables {
|
||||||
|
t.Run(table.name, func(t *testing.T) {
|
||||||
|
got := table.status.String()
|
||||||
|
assert.Equal(t, got, table.want, "they should be equal")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,14 +1,18 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/pion/ice/v2"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/magiconair/properties/assert"
|
||||||
|
"github.com/pion/ice/v2"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
sproto "github.com/netbirdio/netbird/signal/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
var connConf = ConnConfig{
|
var connConf = ConnConfig{
|
||||||
@@ -25,7 +29,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
|
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
|
||||||
"Tailscale", "tailscale"}
|
"Tailscale", "tailscale"}
|
||||||
|
|
||||||
filter := interfaceFilter(ignore)
|
filter := stdnet.InterfaceFilter(ignore)
|
||||||
|
|
||||||
for _, s := range ignore {
|
for _, s := range ignore {
|
||||||
assert.Equal(t, filter(s), false)
|
assert.Equal(t, filter(s), false)
|
||||||
@@ -34,7 +38,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
conn, err := NewConn(connConf, nil)
|
conn, err := NewConn(connConf, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -46,7 +50,7 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -80,7 +84,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -113,7 +117,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -140,7 +144,7 @@ func TestConn_Status(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_Close(t *testing.T) {
|
func TestConn_Close(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -165,3 +169,274 @@ func TestConn_Close(t *testing.T) {
|
|||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockICECandidate struct {
|
||||||
|
ice.Candidate
|
||||||
|
AddressFunc func() string
|
||||||
|
TypeFunc func() ice.CandidateType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Address mocks and overwrite ice.Candidate Address method
|
||||||
|
func (m *mockICECandidate) Address() string {
|
||||||
|
if m.AddressFunc != nil {
|
||||||
|
return m.AddressFunc()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type mocks and overwrite ice.Candidate Type method
|
||||||
|
func (m *mockICECandidate) Type() ice.CandidateType {
|
||||||
|
if m.TypeFunc != nil {
|
||||||
|
return m.TypeFunc()
|
||||||
|
}
|
||||||
|
return ice.CandidateTypeUnspecified
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConn_ShouldUseProxy(t *testing.T) {
|
||||||
|
publicHostCandidate := &mockICECandidate{
|
||||||
|
AddressFunc: func() string {
|
||||||
|
return "8.8.8.8"
|
||||||
|
},
|
||||||
|
TypeFunc: func() ice.CandidateType {
|
||||||
|
return ice.CandidateTypeHost
|
||||||
|
},
|
||||||
|
}
|
||||||
|
privateHostCandidate := &mockICECandidate{
|
||||||
|
AddressFunc: func() string {
|
||||||
|
return "10.0.0.1:44576"
|
||||||
|
},
|
||||||
|
TypeFunc: func() ice.CandidateType {
|
||||||
|
return ice.CandidateTypeHost
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
srflxCandidate := &mockICECandidate{
|
||||||
|
AddressFunc: func() string {
|
||||||
|
return "1.1.1.1"
|
||||||
|
},
|
||||||
|
TypeFunc: func() ice.CandidateType {
|
||||||
|
return ice.CandidateTypeServerReflexive
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
prflxCandidate := &mockICECandidate{
|
||||||
|
AddressFunc: func() string {
|
||||||
|
return "1.1.1.1"
|
||||||
|
},
|
||||||
|
TypeFunc: func() ice.CandidateType {
|
||||||
|
return ice.CandidateTypePeerReflexive
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
relayCandidate := &mockICECandidate{
|
||||||
|
AddressFunc: func() string {
|
||||||
|
return "1.1.1.1"
|
||||||
|
},
|
||||||
|
TypeFunc: func() ice.CandidateType {
|
||||||
|
return ice.CandidateTypeRelay
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
candatePair *ice.CandidatePair
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Use Proxy When Local Candidate Is Relay",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: relayCandidate,
|
||||||
|
Remote: privateHostCandidate,
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Use Proxy When Remote Candidate Is Relay",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: privateHostCandidate,
|
||||||
|
Remote: relayCandidate,
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Use Proxy When Local Candidate Is Peer Reflexive",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: prflxCandidate,
|
||||||
|
Remote: privateHostCandidate,
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Use Proxy When Remote Candidate Is Peer Reflexive",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: privateHostCandidate,
|
||||||
|
Remote: prflxCandidate,
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Don't Use Proxy When Local Candidate Is Public And Remote Is Private",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: publicHostCandidate,
|
||||||
|
Remote: privateHostCandidate,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Don't Use Proxy When Remote Candidate Is Public And Local Is Private",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: privateHostCandidate,
|
||||||
|
Remote: publicHostCandidate,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Don't Use Proxy When Local Candidate is Public And Remote Is Server Reflexive",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: publicHostCandidate,
|
||||||
|
Remote: srflxCandidate,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Don't Use Proxy When Remote Candidate is Public And Local Is Server Reflexive",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: srflxCandidate,
|
||||||
|
Remote: publicHostCandidate,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Don't Use Proxy When Both Candidates Are Public",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: publicHostCandidate,
|
||||||
|
Remote: publicHostCandidate,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Don't Use Proxy When Both Candidates Are Private",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: privateHostCandidate,
|
||||||
|
Remote: privateHostCandidate,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
result := shouldUseProxy(testCase.candatePair, false)
|
||||||
|
if result != testCase.expected {
|
||||||
|
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetProxyWithMessageExchange(t *testing.T) {
|
||||||
|
publicHostCandidate := &mockICECandidate{
|
||||||
|
AddressFunc: func() string {
|
||||||
|
return "8.8.8.8"
|
||||||
|
},
|
||||||
|
TypeFunc: func() ice.CandidateType {
|
||||||
|
return ice.CandidateTypeHost
|
||||||
|
},
|
||||||
|
}
|
||||||
|
relayCandidate := &mockICECandidate{
|
||||||
|
AddressFunc: func() string {
|
||||||
|
return "1.1.1.1"
|
||||||
|
},
|
||||||
|
TypeFunc: func() ice.CandidateType {
|
||||||
|
return ice.CandidateTypeRelay
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
candatePair *ice.CandidatePair
|
||||||
|
inputDirectModeSupport bool
|
||||||
|
inputRemoteModeMessage bool
|
||||||
|
expected proxy.Type
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should Result In Using Wireguard Proxy When Local Eval Is Use Proxy",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: relayCandidate,
|
||||||
|
Remote: publicHostCandidate,
|
||||||
|
},
|
||||||
|
inputDirectModeSupport: true,
|
||||||
|
inputRemoteModeMessage: true,
|
||||||
|
expected: proxy.TypeWireGuard,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: publicHostCandidate,
|
||||||
|
Remote: publicHostCandidate,
|
||||||
|
},
|
||||||
|
inputDirectModeSupport: true,
|
||||||
|
inputRemoteModeMessage: false,
|
||||||
|
expected: proxy.TypeWireGuard,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: relayCandidate,
|
||||||
|
Remote: publicHostCandidate,
|
||||||
|
},
|
||||||
|
inputDirectModeSupport: false,
|
||||||
|
inputRemoteModeMessage: false,
|
||||||
|
expected: proxy.TypeWireGuard,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: publicHostCandidate,
|
||||||
|
Remote: publicHostCandidate,
|
||||||
|
},
|
||||||
|
inputDirectModeSupport: false,
|
||||||
|
inputRemoteModeMessage: false,
|
||||||
|
expected: proxy.TypeDirectNoProxy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
|
||||||
|
candatePair: &ice.CandidatePair{
|
||||||
|
Local: publicHostCandidate,
|
||||||
|
Remote: publicHostCandidate,
|
||||||
|
},
|
||||||
|
inputDirectModeSupport: true,
|
||||||
|
inputRemoteModeMessage: true,
|
||||||
|
expected: proxy.TypeDirectNoProxy,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
g := errgroup.Group{}
|
||||||
|
conn, err := NewConn(connConf, nil, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
conn.meta.protoSupport.DirectCheck = testCase.inputDirectModeSupport
|
||||||
|
conn.SetSendSignalMessage(func(message *sproto.Message) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
g.Go(func() error {
|
||||||
|
return conn.OnModeMessage(ModeMessage{
|
||||||
|
Direct: testCase.inputRemoteModeMessage,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
resultProxy := conn.getProxyWithMessageExchange(testCase.candatePair, 1000)
|
||||||
|
|
||||||
|
err = g.Wait()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if resultProxy.Type() != testCase.expected {
|
||||||
|
t.Errorf("result didn't match expected value: Expected: %s, Got: %s", testCase.expected, resultProxy.Type())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
11
client/internal/peer/listener.go
Normal file
11
client/internal/peer/listener.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
// Listener is a callback type about the NetBird network connection state
|
||||||
|
type Listener interface {
|
||||||
|
OnConnected()
|
||||||
|
OnDisconnected()
|
||||||
|
OnConnecting()
|
||||||
|
OnDisconnecting()
|
||||||
|
OnAddressChanged(string, string)
|
||||||
|
OnPeersListChanged(int)
|
||||||
|
}
|
||||||
142
client/internal/peer/notifier.go
Normal file
142
client/internal/peer/notifier.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
stateDisconnected = iota
|
||||||
|
stateConnected
|
||||||
|
stateConnecting
|
||||||
|
stateDisconnecting
|
||||||
|
)
|
||||||
|
|
||||||
|
type notifier struct {
|
||||||
|
serverStateLock sync.Mutex
|
||||||
|
listenersLock sync.Mutex
|
||||||
|
listener Listener
|
||||||
|
currentClientState bool
|
||||||
|
lastNotification int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNotifier() *notifier {
|
||||||
|
return ¬ifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) setListener(listener Listener) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
n.notifyListener(listener, n.lastNotification)
|
||||||
|
n.serverStateLock.Unlock()
|
||||||
|
|
||||||
|
n.listener = listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) removeListener() {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
n.listener = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
defer n.serverStateLock.Unlock()
|
||||||
|
|
||||||
|
calculatedState := n.calculateState(mgmState, signalState)
|
||||||
|
|
||||||
|
if !n.isServerStateChanged(calculatedState) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.lastNotification = calculatedState
|
||||||
|
|
||||||
|
n.notify(n.lastNotification)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) clientStart() {
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
defer n.serverStateLock.Unlock()
|
||||||
|
n.currentClientState = true
|
||||||
|
n.lastNotification = stateConnected
|
||||||
|
n.notify(n.lastNotification)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) clientStop() {
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
defer n.serverStateLock.Unlock()
|
||||||
|
n.currentClientState = false
|
||||||
|
n.lastNotification = stateDisconnected
|
||||||
|
n.notify(n.lastNotification)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) clientTearDown() {
|
||||||
|
n.serverStateLock.Lock()
|
||||||
|
defer n.serverStateLock.Unlock()
|
||||||
|
n.currentClientState = false
|
||||||
|
n.lastNotification = stateDisconnecting
|
||||||
|
n.notify(n.lastNotification)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) isServerStateChanged(newState int) bool {
|
||||||
|
return n.lastNotification != newState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) notify(state int) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
if n.listener == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.notifyListener(n.listener, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) notifyListener(l Listener, state int) {
|
||||||
|
go func() {
|
||||||
|
switch state {
|
||||||
|
case stateDisconnected:
|
||||||
|
l.OnDisconnected()
|
||||||
|
case stateConnected:
|
||||||
|
l.OnConnected()
|
||||||
|
case stateConnecting:
|
||||||
|
l.OnConnecting()
|
||||||
|
case stateDisconnecting:
|
||||||
|
l.OnDisconnecting()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) calculateState(managementConn, signalConn bool) int {
|
||||||
|
if managementConn && signalConn {
|
||||||
|
return stateConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
if !managementConn && !signalConn {
|
||||||
|
return stateDisconnected
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.lastNotification == stateDisconnecting {
|
||||||
|
return stateDisconnecting
|
||||||
|
}
|
||||||
|
|
||||||
|
return stateConnecting
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) peerListChanged(numOfPeers int) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
if n.listener == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.listener.OnPeersListChanged(numOfPeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *notifier) localAddressChanged(fqdn, address string) {
|
||||||
|
n.listenersLock.Lock()
|
||||||
|
defer n.listenersLock.Unlock()
|
||||||
|
if n.listener == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.listener.OnAddressChanged(fqdn, address)
|
||||||
|
}
|
||||||
97
client/internal/peer/notifier_test.go
Normal file
97
client/internal/peer/notifier_test.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mocListener struct {
|
||||||
|
lastState int
|
||||||
|
wg sync.WaitGroup
|
||||||
|
peers int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *mocListener) OnConnected() {
|
||||||
|
l.lastState = stateConnected
|
||||||
|
l.wg.Done()
|
||||||
|
}
|
||||||
|
func (l *mocListener) OnDisconnected() {
|
||||||
|
l.lastState = stateDisconnected
|
||||||
|
l.wg.Done()
|
||||||
|
}
|
||||||
|
func (l *mocListener) OnConnecting() {
|
||||||
|
l.lastState = stateConnecting
|
||||||
|
l.wg.Done()
|
||||||
|
}
|
||||||
|
func (l *mocListener) OnDisconnecting() {
|
||||||
|
l.lastState = stateDisconnecting
|
||||||
|
l.wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *mocListener) OnAddressChanged(host, addr string) {
|
||||||
|
|
||||||
|
}
|
||||||
|
func (l *mocListener) OnPeersListChanged(size int) {
|
||||||
|
l.peers = size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *mocListener) setWaiter() {
|
||||||
|
l.wg.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *mocListener) wait() {
|
||||||
|
l.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_notifier_serverState(t *testing.T) {
|
||||||
|
|
||||||
|
type scenario struct {
|
||||||
|
name string
|
||||||
|
expected int
|
||||||
|
mgmState bool
|
||||||
|
signalState bool
|
||||||
|
}
|
||||||
|
scenarios := []scenario{
|
||||||
|
{"connected", stateConnected, true, true},
|
||||||
|
{"mgm down", stateConnecting, false, true},
|
||||||
|
{"signal down", stateConnecting, true, false},
|
||||||
|
{"disconnected", stateDisconnected, false, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range scenarios {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
n := newNotifier()
|
||||||
|
n.updateServerStates(tt.mgmState, tt.signalState)
|
||||||
|
if n.lastNotification != tt.expected {
|
||||||
|
t.Errorf("invalid serverstate: %d, expected: %d", n.lastNotification, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_notifier_SetListener(t *testing.T) {
|
||||||
|
listener := &mocListener{}
|
||||||
|
listener.setWaiter()
|
||||||
|
|
||||||
|
n := newNotifier()
|
||||||
|
n.lastNotification = stateConnecting
|
||||||
|
n.setListener(listener)
|
||||||
|
listener.wait()
|
||||||
|
if listener.lastState != n.lastNotification {
|
||||||
|
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_notifier_RemoveListener(t *testing.T) {
|
||||||
|
listener := &mocListener{}
|
||||||
|
listener.setWaiter()
|
||||||
|
n := newNotifier()
|
||||||
|
n.lastNotification = stateConnecting
|
||||||
|
n.setListener(listener)
|
||||||
|
n.removeListener()
|
||||||
|
n.peerListChanged(1)
|
||||||
|
|
||||||
|
if listener.peers != 0 {
|
||||||
|
t.Errorf("invalid state: %d", listener.peers)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,25 +1,316 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import log "github.com/sirupsen/logrus"
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type ConnStatus int
|
// State contains the latest state of a peer
|
||||||
|
type State struct {
|
||||||
|
IP string
|
||||||
|
PubKey string
|
||||||
|
FQDN string
|
||||||
|
ConnStatus ConnStatus
|
||||||
|
ConnStatusUpdate time.Time
|
||||||
|
Relayed bool
|
||||||
|
Direct bool
|
||||||
|
LocalIceCandidateType string
|
||||||
|
RemoteIceCandidateType string
|
||||||
|
}
|
||||||
|
|
||||||
func (s ConnStatus) String() string {
|
// LocalPeerState contains the latest state of the local peer
|
||||||
switch s {
|
type LocalPeerState struct {
|
||||||
case StatusConnecting:
|
IP string
|
||||||
return "Connecting"
|
PubKey string
|
||||||
case StatusConnected:
|
KernelInterface bool
|
||||||
return "Connected"
|
FQDN string
|
||||||
case StatusDisconnected:
|
}
|
||||||
return "Disconnected"
|
|
||||||
default:
|
// SignalState contains the latest state of a signal connection
|
||||||
log.Errorf("unknown status: %d", s)
|
type SignalState struct {
|
||||||
return "INVALID_PEER_CONNECTION_STATUS"
|
URL string
|
||||||
|
Connected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagementState contains the latest state of a management connection
|
||||||
|
type ManagementState struct {
|
||||||
|
URL string
|
||||||
|
Connected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// FullStatus contains the full state held by the Status instance
|
||||||
|
type FullStatus struct {
|
||||||
|
Peers []State
|
||||||
|
ManagementState ManagementState
|
||||||
|
SignalState SignalState
|
||||||
|
LocalPeerState LocalPeerState
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status holds a state of peers, signal and management connections
|
||||||
|
type Status struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
peers map[string]State
|
||||||
|
changeNotify map[string]chan struct{}
|
||||||
|
signalState bool
|
||||||
|
managementState bool
|
||||||
|
localPeer LocalPeerState
|
||||||
|
offlinePeers []State
|
||||||
|
mgmAddress string
|
||||||
|
signalAddress string
|
||||||
|
notifier *notifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecorder returns a new Status instance
|
||||||
|
func NewRecorder(mgmAddress string) *Status {
|
||||||
|
return &Status{
|
||||||
|
peers: make(map[string]State),
|
||||||
|
changeNotify: make(map[string]chan struct{}),
|
||||||
|
offlinePeers: make([]State, 0),
|
||||||
|
notifier: newNotifier(),
|
||||||
|
mgmAddress: mgmAddress,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
// ReplaceOfflinePeers replaces
|
||||||
StatusConnected ConnStatus = iota
|
func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||||
StatusConnecting
|
d.mux.Lock()
|
||||||
StatusDisconnected
|
defer d.mux.Unlock()
|
||||||
)
|
d.offlinePeers = make([]State, len(replacement))
|
||||||
|
copy(d.offlinePeers, replacement)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeer adds peer to Daemon status map
|
||||||
|
func (d *Status) AddPeer(peerPubKey string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
_, ok := d.peers[peerPubKey]
|
||||||
|
if ok {
|
||||||
|
return errors.New("peer already exist")
|
||||||
|
}
|
||||||
|
d.peers[peerPubKey] = State{PubKey: peerPubKey, ConnStatus: StatusDisconnected}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeer adds peer to Daemon status map
|
||||||
|
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
state, ok := d.peers[peerPubKey]
|
||||||
|
if !ok {
|
||||||
|
return State{}, errors.New("peer not found")
|
||||||
|
}
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes peer from Daemon status map
|
||||||
|
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
_, ok := d.peers[peerPubKey]
|
||||||
|
if ok {
|
||||||
|
delete(d.peers, peerPubKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d.notifyPeerListChanged()
|
||||||
|
return errors.New("no peer with to remove")
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerState updates peer status
|
||||||
|
func (d *Status) UpdatePeerState(receivedState State) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[receivedState.PubKey]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
if receivedState.IP != "" {
|
||||||
|
peerState.IP = receivedState.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||||
|
peerState.ConnStatus = receivedState.ConnStatus
|
||||||
|
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
||||||
|
peerState.Direct = receivedState.Direct
|
||||||
|
peerState.Relayed = receivedState.Relayed
|
||||||
|
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
|
||||||
|
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
|
||||||
|
}
|
||||||
|
|
||||||
|
d.peers[receivedState.PubKey] = peerState
|
||||||
|
|
||||||
|
ch, found := d.changeNotify[receivedState.PubKey]
|
||||||
|
if found && ch != nil {
|
||||||
|
close(ch)
|
||||||
|
d.changeNotify[receivedState.PubKey] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d.notifyPeerListChanged()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerFQDN update peer's state fqdn only
|
||||||
|
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[peerPubKey]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState.FQDN = fqdn
|
||||||
|
d.peers[peerPubKey] = peerState
|
||||||
|
|
||||||
|
d.notifyPeerListChanged()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
||||||
|
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
ch, found := d.changeNotify[peer]
|
||||||
|
if !found || ch == nil {
|
||||||
|
ch = make(chan struct{})
|
||||||
|
d.changeNotify[peer] = ch
|
||||||
|
}
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLocalPeerState updates local peer status
|
||||||
|
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
d.localPeer = localPeerState
|
||||||
|
d.notifyAddressChanged()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanLocalPeerState cleans local peer status
|
||||||
|
func (d *Status) CleanLocalPeerState() {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
d.localPeer = LocalPeerState{}
|
||||||
|
d.notifyAddressChanged()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||||
|
func (d *Status) MarkManagementDisconnected() {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
|
d.managementState = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkManagementConnected sets ManagementState to connected
|
||||||
|
func (d *Status) MarkManagementConnected() {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
|
d.managementState = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSignalAddress update the address of the signal server
|
||||||
|
func (d *Status) UpdateSignalAddress(signalURL string) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.signalAddress = signalURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateManagementAddress update the address of the management server
|
||||||
|
func (d *Status) UpdateManagementAddress(mgmAddress string) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.mgmAddress = mgmAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkSignalDisconnected sets SignalState to disconnected
|
||||||
|
func (d *Status) MarkSignalDisconnected() {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
|
d.signalState = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkSignalConnected sets SignalState to connected
|
||||||
|
func (d *Status) MarkSignalConnected() {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
|
d.signalState = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFullStatus gets full status
|
||||||
|
func (d *Status) GetFullStatus() FullStatus {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
fullStatus := FullStatus{
|
||||||
|
ManagementState: ManagementState{
|
||||||
|
d.mgmAddress,
|
||||||
|
d.managementState,
|
||||||
|
},
|
||||||
|
SignalState: SignalState{
|
||||||
|
d.signalAddress,
|
||||||
|
d.signalState,
|
||||||
|
},
|
||||||
|
LocalPeerState: d.localPeer,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, status := range d.peers {
|
||||||
|
fullStatus.Peers = append(fullStatus.Peers, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
|
||||||
|
|
||||||
|
return fullStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientStart will notify all listeners about the new service state
|
||||||
|
func (d *Status) ClientStart() {
|
||||||
|
d.notifier.clientStart()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientStop will notify all listeners about the new service state
|
||||||
|
func (d *Status) ClientStop() {
|
||||||
|
d.notifier.clientStop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientTeardown will notify all listeners about the service is under teardown
|
||||||
|
func (d *Status) ClientTeardown() {
|
||||||
|
d.notifier.clientTearDown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConnectionListener set a listener to the notifier
|
||||||
|
func (d *Status) SetConnectionListener(listener Listener) {
|
||||||
|
d.notifier.setListener(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveConnectionListener remove the listener from the notifier
|
||||||
|
func (d *Status) RemoveConnectionListener() {
|
||||||
|
d.notifier.removeListener()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) onConnectionChanged() {
|
||||||
|
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) notifyPeerListChanged() {
|
||||||
|
d.notifier.peerListChanged(len(d.peers))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) notifyAddressChanged() {
|
||||||
|
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,27 +1,233 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/magiconair/properties/assert"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConnStatus_String(t *testing.T) {
|
func TestAddPeer(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
err := status.AddPeer(key)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
tables := []struct {
|
_, exists := status.peers[key]
|
||||||
name string
|
assert.True(t, exists, "value was found")
|
||||||
status ConnStatus
|
|
||||||
want string
|
err = status.AddPeer(key)
|
||||||
}{
|
|
||||||
{"StatusConnected", StatusConnected, "Connected"},
|
assert.Error(t, err, "should return error on duplicate")
|
||||||
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
}
|
||||||
{"StatusConnecting", StatusConnecting, "Connecting"},
|
|
||||||
|
func TestGetPeer(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
err := status.AddPeer(key)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
peerStatus, err := status.GetPeer(key)
|
||||||
|
assert.NoError(t, err, "shouldn't return error on getting peer")
|
||||||
|
|
||||||
|
assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match")
|
||||||
|
|
||||||
|
_, err = status.GetPeer("non_existing_key")
|
||||||
|
assert.Error(t, err, "should return error when peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatePeerState(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
ip := "10.10.10.10"
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
peerState := State{
|
||||||
|
PubKey: key,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
status.peers[key] = peerState
|
||||||
t.Run(table.name, func(t *testing.T) {
|
|
||||||
got := table.status.String()
|
peerState.IP = ip
|
||||||
assert.Equal(t, got, table.want, "they should be equal")
|
|
||||||
|
err := status.UpdatePeerState(peerState)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
state, exists := status.peers[key]
|
||||||
|
assert.True(t, exists, "state should be found")
|
||||||
|
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
fqdn := "peer-a.netbird.local"
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
peerState := State{
|
||||||
|
PubKey: key,
|
||||||
|
}
|
||||||
|
|
||||||
|
status.peers[key] = peerState
|
||||||
|
|
||||||
|
err := status.UpdatePeerFQDN(key, fqdn)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
state, exists := status.peers[key]
|
||||||
|
assert.True(t, exists, "state should be found")
|
||||||
|
assert.Equal(t, fqdn, state.FQDN, "fqdn should be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
ip := "10.10.10.10"
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
peerState := State{
|
||||||
|
PubKey: key,
|
||||||
|
}
|
||||||
|
|
||||||
|
status.peers[key] = peerState
|
||||||
|
|
||||||
|
ch := status.GetPeerStateChangeNotifier(key)
|
||||||
|
assert.NotNil(t, ch, "channel shouldn't be nil")
|
||||||
|
|
||||||
|
peerState.IP = ip
|
||||||
|
|
||||||
|
err := status.UpdatePeerState(peerState)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
default:
|
||||||
|
t.Errorf("channel wasn't closed after update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemovePeer(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
peerState := State{
|
||||||
|
PubKey: key,
|
||||||
|
}
|
||||||
|
|
||||||
|
status.peers[key] = peerState
|
||||||
|
|
||||||
|
err := status.RemovePeer(key)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
_, exists := status.peers[key]
|
||||||
|
assert.False(t, exists, "state value shouldn't be found")
|
||||||
|
|
||||||
|
err = status.RemovePeer("not existing")
|
||||||
|
assert.Error(t, err, "should return error when peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateLocalPeerState(t *testing.T) {
|
||||||
|
localPeerState := LocalPeerState{
|
||||||
|
IP: "10.10.10.10",
|
||||||
|
PubKey: "abc",
|
||||||
|
KernelInterface: false,
|
||||||
|
}
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
|
||||||
|
status.UpdateLocalPeerState(localPeerState)
|
||||||
|
|
||||||
|
assert.Equal(t, localPeerState, status.localPeer, "local peer status should be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanLocalPeerState(t *testing.T) {
|
||||||
|
emptyLocalPeerState := LocalPeerState{}
|
||||||
|
localPeerState := LocalPeerState{
|
||||||
|
IP: "10.10.10.10",
|
||||||
|
PubKey: "abc",
|
||||||
|
KernelInterface: false,
|
||||||
|
}
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
|
||||||
|
status.localPeer = localPeerState
|
||||||
|
|
||||||
|
status.CleanLocalPeerState()
|
||||||
|
|
||||||
|
assert.Equal(t, emptyLocalPeerState, status.localPeer, "local peer status should be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateSignalState(t *testing.T) {
|
||||||
|
url := "https://signal"
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
connected bool
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"should mark as connected", true, true},
|
||||||
|
{"should mark as disconnected", false, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
status.UpdateSignalAddress(url)
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
if test.connected {
|
||||||
|
status.MarkSignalConnected()
|
||||||
|
} else {
|
||||||
|
status.MarkSignalDisconnected()
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.want, status.signalState, "signal status should be equal")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateManagementState(t *testing.T) {
|
||||||
|
url := "https://management"
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
connected bool
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"should mark as connected", true, true},
|
||||||
|
{"should mark as disconnected", false, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
status := NewRecorder(url)
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
if test.connected {
|
||||||
|
status.MarkManagementConnected()
|
||||||
|
} else {
|
||||||
|
status.MarkManagementDisconnected()
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.want, status.managementState, "signalState status should be equal")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFullStatus(t *testing.T) {
|
||||||
|
key1 := "abc"
|
||||||
|
key2 := "def"
|
||||||
|
signalAddr := "https://signal"
|
||||||
|
managementState := ManagementState{
|
||||||
|
URL: "https://mgm",
|
||||||
|
Connected: true,
|
||||||
|
}
|
||||||
|
signalState := SignalState{
|
||||||
|
URL: signalAddr,
|
||||||
|
Connected: true,
|
||||||
|
}
|
||||||
|
peerState1 := State{
|
||||||
|
PubKey: key1,
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState2 := State{
|
||||||
|
PubKey: key2,
|
||||||
|
}
|
||||||
|
|
||||||
|
status := NewRecorder("https://mgm")
|
||||||
|
status.UpdateSignalAddress(signalAddr)
|
||||||
|
|
||||||
|
status.managementState = managementState.Connected
|
||||||
|
status.signalState = signalState.Connected
|
||||||
|
status.peers[key1] = peerState1
|
||||||
|
status.peers[key2] = peerState2
|
||||||
|
|
||||||
|
fullStatus := status.GetFullStatus()
|
||||||
|
|
||||||
|
assert.Equal(t, managementState, fullStatus.ManagementState, "management status should be equal")
|
||||||
|
assert.Equal(t, signalState, fullStatus.SignalState, "signal status should be equal")
|
||||||
|
assert.ElementsMatch(t, []State{peerState1, peerState2}, fullStatus.Peers, "peers states should match")
|
||||||
}
|
}
|
||||||
|
|||||||
11
client/internal/peer/stdnet.go
Normal file
11
client/internal/peer/stdnet.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNet(conn.config.InterfaceBlackList)
|
||||||
|
}
|
||||||
7
client/internal/peer/stdnet_android.go
Normal file
7
client/internal/peer/stdnet_android.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
|
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
|
||||||
|
}
|
||||||
67
client/internal/proxy/direct.go
Normal file
67
client/internal/proxy/direct.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
|
||||||
|
// This is possible in either of these cases:
|
||||||
|
// - peers are in the same local network
|
||||||
|
// - one of the peers has a public static IP (host)
|
||||||
|
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
|
||||||
|
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
|
||||||
|
type DirectNoProxy struct {
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
|
||||||
|
remoteKey string
|
||||||
|
allowedIps string
|
||||||
|
|
||||||
|
// RemoteWgListenPort is a WireGuard port of a remote peer.
|
||||||
|
// It is used instead of the hardcoded 51820 port.
|
||||||
|
remoteWgListenPort int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
|
||||||
|
func NewDirectNoProxy(wgInterface *iface.WGIface, remoteKey string, allowedIps string, remoteWgPort int) *DirectNoProxy {
|
||||||
|
return &DirectNoProxy{
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
remoteKey: remoteKey,
|
||||||
|
allowedIps: allowedIps,
|
||||||
|
remoteWgListenPort: remoteWgPort}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close removes peer from the WireGuard interface
|
||||||
|
func (p *DirectNoProxy) Close() error {
|
||||||
|
err := p.wgInterface.RemovePeer(p.remoteKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start just updates WireGuard peer with the remote IP and default WireGuard port
|
||||||
|
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
|
||||||
|
|
||||||
|
log.Debugf("using DirectNoProxy while connecting to peer %s", p.remoteKey)
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
addr.Port = p.remoteWgListenPort
|
||||||
|
err = p.wgInterface.UpdatePeer(p.remoteKey, p.allowedIps, addr)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the type of this proxy
|
||||||
|
func (p *DirectNoProxy) Type() Type {
|
||||||
|
return TypeDirectNoProxy
|
||||||
|
}
|
||||||
@@ -5,24 +5,18 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NoProxy is used when there is no need for a proxy between ICE and Wireguard.
|
// NoProxy is used just to configure WireGuard without any local proxy in between.
|
||||||
// This is possible in either of these cases:
|
// Used when the WireGuard interface is userspace and uses bind.ICEBind
|
||||||
// - peers are in the same local network
|
|
||||||
// - one of the peers has a public static IP (host)
|
|
||||||
// NoProxy will just update remote peer with a remote host and fixed Wireguard port (r.g. 51820).
|
|
||||||
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
|
|
||||||
type NoProxy struct {
|
type NoProxy struct {
|
||||||
config Config
|
config Config
|
||||||
// RemoteWgListenPort is a WireGuard port of a remote peer.
|
|
||||||
// It is used instead of the hardcoded 51820 port.
|
|
||||||
RemoteWgListenPort int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port
|
// NewNoProxy creates a new NoProxy with a provided config
|
||||||
func NewNoProxy(config Config, remoteWgPort int) *NoProxy {
|
func NewNoProxy(config Config) *NoProxy {
|
||||||
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort}
|
return &NoProxy{config: config}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close removes peer from the WireGuard interface
|
||||||
func (p *NoProxy) Close() error {
|
func (p *NoProxy) Close() error {
|
||||||
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -31,23 +25,16 @@ func (p *NoProxy) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start just updates Wireguard peer with the remote IP and default Wireguard port
|
// Start just updates WireGuard peer with the remote address
|
||||||
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
func (p *NoProxy) Start(remoteConn net.Conn) error {
|
||||||
|
|
||||||
log.Debugf("using NoProxy while connecting to peer %s", p.config.RemoteKey)
|
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
|
||||||
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
addr.Port = p.RemoteWgListenPort
|
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
||||||
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
|
|
||||||
addr, p.config.PreSharedKey)
|
addr, p.config.PreSharedKey)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *NoProxy) Type() Type {
|
func (p *NoProxy) Type() Type {
|
||||||
|
|||||||
@@ -1,31 +1,19 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultWgKeepAlive = 25 * time.Second
|
|
||||||
|
|
||||||
type Type string
|
type Type string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TypeNoProxy Type = "NoProxy"
|
TypeDirectNoProxy Type = "DirectNoProxy"
|
||||||
TypeWireguard Type = "Wireguard"
|
TypeWireGuard Type = "WireGuard"
|
||||||
TypeDummy Type = "Dummy"
|
TypeDummy Type = "Dummy"
|
||||||
|
TypeNoProxy Type = "NoProxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
WgListenAddr string
|
|
||||||
RemoteKey string
|
|
||||||
WgInterface *iface.WGIface
|
|
||||||
AllowedIps string
|
|
||||||
PreSharedKey *wgtypes.Key
|
|
||||||
}
|
|
||||||
|
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
// Start creates a local remoteConn and starts proxying data from/to remoteConn
|
// Start creates a local remoteConn and starts proxying data from/to remoteConn
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WireguardProxy proxies
|
// WireGuardProxy proxies
|
||||||
type WireguardProxy struct {
|
type WireGuardProxy struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
@@ -17,13 +17,13 @@ type WireguardProxy struct {
|
|||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireguardProxy(config Config) *WireguardProxy {
|
func NewWireGuardProxy(config Config) *WireGuardProxy {
|
||||||
p := &WireguardProxy{config: config}
|
p := &WireGuardProxy{config: config}
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireguardProxy) updateEndpoint() error {
|
func (p *WireGuardProxy) updateEndpoint() error {
|
||||||
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -38,7 +38,7 @@ func (p *WireguardProxy) updateEndpoint() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireguardProxy) Start(remoteConn net.Conn) error {
|
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
@@ -60,7 +60,7 @@ func (p *WireguardProxy) Start(remoteConn net.Conn) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireguardProxy) Close() error {
|
func (p *WireGuardProxy) Close() error {
|
||||||
p.cancel()
|
p.cancel()
|
||||||
if c := p.localConn; c != nil {
|
if c := p.localConn; c != nil {
|
||||||
err := p.localConn.Close()
|
err := p.localConn.Close()
|
||||||
@@ -77,7 +77,7 @@ func (p *WireguardProxy) Close() error {
|
|||||||
|
|
||||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||||
// blocks
|
// blocks
|
||||||
func (p *WireguardProxy) proxyToRemote() {
|
func (p *WireGuardProxy) proxyToRemote() {
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
@@ -101,7 +101,7 @@ func (p *WireguardProxy) proxyToRemote() {
|
|||||||
|
|
||||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||||
// blocks
|
// blocks
|
||||||
func (p *WireguardProxy) proxyToLocal() {
|
func (p *WireGuardProxy) proxyToLocal() {
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
@@ -123,6 +123,6 @@ func (p *WireguardProxy) proxyToLocal() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WireguardProxy) Type() Type {
|
func (p *WireGuardProxy) Type() Type {
|
||||||
return TypeWireguard
|
return TypeWireGuard
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,11 +5,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type routerPeerStatus struct {
|
type routerPeerStatus struct {
|
||||||
@@ -26,7 +26,7 @@ type routesUpdate struct {
|
|||||||
type clientNetwork struct {
|
type clientNetwork struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
statusRecorder *status.Status
|
statusRecorder *peer.Status
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
routes map[string]*route.Route
|
routes map[string]*route.Route
|
||||||
routeUpdate chan routesUpdate
|
routeUpdate chan routesUpdate
|
||||||
@@ -37,7 +37,7 @@ type clientNetwork struct {
|
|||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *status.Status, network netip.Prefix) *clientNetwork {
|
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -62,7 +62,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
routePeerStatuses[r.ID] = routerPeerStatus{
|
routePeerStatuses[r.ID] = routerPeerStatus{
|
||||||
connected: peerStatus.ConnStatus == peer.StatusConnected.String(),
|
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
||||||
relayed: peerStatus.Relayed,
|
relayed: peerStatus.Relayed,
|
||||||
direct: peerStatus.Direct,
|
direct: peerStatus.Direct,
|
||||||
}
|
}
|
||||||
@@ -123,7 +123,7 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
|||||||
return
|
return
|
||||||
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
|
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
|
||||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||||
if err != nil || state.ConnStatus == peer.StatusConnecting.String() {
|
if err != nil || state.ConnStatus == peer.StatusConnecting {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
peerStateUpdate <- struct{}{}
|
peerStateUpdate <- struct{}{}
|
||||||
@@ -144,7 +144,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
|||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||||
if err != nil || state.ConnStatus != peer.StatusConnected.String() {
|
if err != nil || state.ConnStatus != peer.StatusConnected {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,7 +162,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.GetAddress().IP.String())
|
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
||||||
c.network, err)
|
c.network, err)
|
||||||
@@ -201,10 +201,10 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String())
|
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||||
c.network.String(), c.wgInterface.GetAddress().IP.String(), err)
|
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,189 +1,9 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import "github.com/netbirdio/netbird/route"
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
Stop()
|
Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultManager is the default instance of a route manager
|
|
||||||
type DefaultManager struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
mux sync.Mutex
|
|
||||||
clientNetworks map[string]*clientNetwork
|
|
||||||
serverRoutes map[string]*route.Route
|
|
||||||
serverRouter *serverRouter
|
|
||||||
statusRecorder *status.Status
|
|
||||||
wgInterface *iface.WGIface
|
|
||||||
pubKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager returns a new route manager
|
|
||||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *DefaultManager {
|
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
|
||||||
return &DefaultManager{
|
|
||||||
ctx: mCTX,
|
|
||||||
stop: cancel,
|
|
||||||
clientNetworks: make(map[string]*clientNetwork),
|
|
||||||
serverRoutes: make(map[string]*route.Route),
|
|
||||||
serverRouter: &serverRouter{
|
|
||||||
routes: make(map[string]*route.Route),
|
|
||||||
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
|
||||||
firewall: NewFirewall(ctx),
|
|
||||||
},
|
|
||||||
statusRecorder: statusRecorder,
|
|
||||||
wgInterface: wgInterface,
|
|
||||||
pubKey: pubKey,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the manager watchers and clean firewall rules
|
|
||||||
func (m *DefaultManager) Stop() {
|
|
||||||
m.stop()
|
|
||||||
m.serverRouter.firewall.CleanRoutingRules()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
|
||||||
// removing routes that do not exist as per the update from the Management service.
|
|
||||||
for id, client := range m.clientNetworks {
|
|
||||||
_, found := networks[id]
|
|
||||||
if !found {
|
|
||||||
log.Debugf("stopping client network watcher, %s", id)
|
|
||||||
client.stop()
|
|
||||||
delete(m.clientNetworks, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, routes := range networks {
|
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
|
||||||
if !found {
|
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
|
||||||
}
|
|
||||||
update := routesUpdate{
|
|
||||||
updateSerial: updateSerial,
|
|
||||||
routes: routes,
|
|
||||||
}
|
|
||||||
|
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
|
||||||
serverRoutesToRemove := make([]string, 0)
|
|
||||||
|
|
||||||
if len(routesMap) > 0 {
|
|
||||||
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for routeID := range m.serverRoutes {
|
|
||||||
update, found := routesMap[routeID]
|
|
||||||
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
|
||||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, routeID := range serverRoutesToRemove {
|
|
||||||
oldRoute := m.serverRoutes[routeID]
|
|
||||||
err := m.removeFromServerNetwork(oldRoute)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
|
||||||
oldRoute.ID, oldRoute.Network, err)
|
|
||||||
}
|
|
||||||
delete(m.serverRoutes, routeID)
|
|
||||||
}
|
|
||||||
|
|
||||||
for id, newRoute := range routesMap {
|
|
||||||
_, found := m.serverRoutes[id]
|
|
||||||
if found {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := m.addToServerNetwork(newRoute)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m.serverRoutes[id] = newRoute
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.serverRoutes) > 0 {
|
|
||||||
err := enableIPForwarding()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
|
||||||
select {
|
|
||||||
case <-m.ctx.Done():
|
|
||||||
log.Infof("not updating routes as context is closed")
|
|
||||||
return m.ctx.Err()
|
|
||||||
default:
|
|
||||||
m.mux.Lock()
|
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
|
||||||
newServerRoutesMap := make(map[string]*route.Route)
|
|
||||||
ownNetworkIDs := make(map[string]bool)
|
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
|
||||||
if newRoute.Peer == m.pubKey {
|
|
||||||
ownNetworkIDs[networkID] = true
|
|
||||||
// only linux is supported for now
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newServerRoutesMap[newRoute.ID] = newRoute
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
|
||||||
if !ownNetworkIDs[networkID] {
|
|
||||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
|
||||||
// we skip this route management
|
|
||||||
if newRoute.Network.Bits() < 7 {
|
|
||||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
|
||||||
system.NetbirdVersion(), newRoute.Network)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
|
||||||
|
|
||||||
err := m.updateServerRoutes(newServerRoutesMap)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
31
client/internal/routemanager/manager_android.go
Normal file
31
client/internal/routemanager/manager_android.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultManager dummy router manager for Android
|
||||||
|
type DefaultManager struct {
|
||||||
|
ctx context.Context
|
||||||
|
serverRouter *serverRouter
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager returns a new dummy route manager what doing nothing
|
||||||
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||||
|
return &DefaultManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRoutes ...
|
||||||
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop ...
|
||||||
|
func (m *DefaultManager) Stop() {
|
||||||
|
|
||||||
|
}
|
||||||
186
client/internal/routemanager/manager_nonandroid.go
Normal file
186
client/internal/routemanager/manager_nonandroid.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultManager is the default instance of a route manager
|
||||||
|
type DefaultManager struct {
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
mux sync.Mutex
|
||||||
|
clientNetworks map[string]*clientNetwork
|
||||||
|
serverRoutes map[string]*route.Route
|
||||||
|
serverRouter *serverRouter
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
pubKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager returns a new route manager
|
||||||
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||||
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
|
return &DefaultManager{
|
||||||
|
ctx: mCTX,
|
||||||
|
stop: cancel,
|
||||||
|
clientNetworks: make(map[string]*clientNetwork),
|
||||||
|
serverRoutes: make(map[string]*route.Route),
|
||||||
|
serverRouter: &serverRouter{
|
||||||
|
routes: make(map[string]*route.Route),
|
||||||
|
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
|
||||||
|
firewall: NewFirewall(ctx),
|
||||||
|
},
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
pubKey: pubKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the manager watchers and clean firewall rules
|
||||||
|
func (m *DefaultManager) Stop() {
|
||||||
|
m.stop()
|
||||||
|
m.serverRouter.firewall.CleanRoutingRules()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||||
|
// removing routes that do not exist as per the update from the Management service.
|
||||||
|
for id, client := range m.clientNetworks {
|
||||||
|
_, found := networks[id]
|
||||||
|
if !found {
|
||||||
|
log.Debugf("stopping client network watcher, %s", id)
|
||||||
|
client.stop()
|
||||||
|
delete(m.clientNetworks, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, routes := range networks {
|
||||||
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
|
if !found {
|
||||||
|
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||||
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
|
}
|
||||||
|
update := routesUpdate{
|
||||||
|
updateSerial: updateSerial,
|
||||||
|
routes: routes,
|
||||||
|
}
|
||||||
|
|
||||||
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
|
||||||
|
serverRoutesToRemove := make([]string, 0)
|
||||||
|
|
||||||
|
if len(routesMap) > 0 {
|
||||||
|
err := m.serverRouter.firewall.RestoreOrCreateContainers()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for routeID := range m.serverRoutes {
|
||||||
|
update, found := routesMap[routeID]
|
||||||
|
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
|
||||||
|
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, routeID := range serverRoutesToRemove {
|
||||||
|
oldRoute := m.serverRoutes[routeID]
|
||||||
|
err := m.removeFromServerNetwork(oldRoute)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||||
|
oldRoute.ID, oldRoute.Network, err)
|
||||||
|
}
|
||||||
|
delete(m.serverRoutes, routeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, newRoute := range routesMap {
|
||||||
|
_, found := m.serverRoutes[id]
|
||||||
|
if found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.addToServerNetwork(newRoute)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.serverRoutes[id] = newRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.serverRoutes) > 0 {
|
||||||
|
err := enableIPForwarding()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||||
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
log.Infof("not updating routes as context is closed")
|
||||||
|
return m.ctx.Err()
|
||||||
|
default:
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||||
|
newServerRoutesMap := make(map[string]*route.Route)
|
||||||
|
ownNetworkIDs := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, newRoute := range newRoutes {
|
||||||
|
networkID := route.GetHAUniqueID(newRoute)
|
||||||
|
if newRoute.Peer == m.pubKey {
|
||||||
|
ownNetworkIDs[networkID] = true
|
||||||
|
// only linux is supported for now
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newServerRoutesMap[newRoute.ID] = newRoute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newRoute := range newRoutes {
|
||||||
|
networkID := route.GetHAUniqueID(newRoute)
|
||||||
|
if !ownNetworkIDs[networkID] {
|
||||||
|
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||||
|
// we skip this route management
|
||||||
|
if newRoute.Network.Bits() < 7 {
|
||||||
|
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
|
||||||
|
version.NetbirdVersion(), newRoute.Network)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||||
|
|
||||||
|
err := m.updateServerRoutes(newServerRoutesMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,14 +3,16 @@ package routemanager
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/status"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small
|
// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small
|
||||||
@@ -390,14 +392,18 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU)
|
newNet, err := stdnet.NewNet()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet)
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
require.NoError(t, err, "should create testing WGIface interface")
|
||||||
defer wgInterface.Close()
|
defer wgInterface.Close()
|
||||||
|
|
||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
|
||||||
statusRecorder := status.NewRecorder()
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
|
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
|
||||||
defer routeManager.Stop()
|
defer routeManager.Stop()
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
|
|||||||
default:
|
default:
|
||||||
m.serverRouter.mux.Lock()
|
m.serverRouter.mux.Lock()
|
||||||
defer m.serverRouter.mux.Unlock()
|
defer m.serverRouter.mux.Unlock()
|
||||||
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route))
|
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -57,7 +57,7 @@ func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
|
|||||||
default:
|
default:
|
||||||
m.serverRouter.mux.Lock()
|
m.serverRouter.mux.Lock()
|
||||||
defer m.serverRouter.mux.Unlock()
|
defer m.serverRouter.mux.Unlock()
|
||||||
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route))
|
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package routemanager
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -32,25 +33,29 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
|
|
||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU)
|
newNet, err := stdnet.NewNet()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet)
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
require.NoError(t, err, "should create testing WGIface interface")
|
||||||
defer wgInterface.Close()
|
defer wgInterface.Close()
|
||||||
|
|
||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
|
||||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String())
|
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String())
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
|
|
||||||
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
if testCase.shouldRouteToWireguard {
|
if testCase.shouldRouteToWireguard {
|
||||||
require.Equal(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||||
} else {
|
} else {
|
||||||
require.NotEqual(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to a different interface")
|
require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetAddress().IP.String())
|
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String())
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
|
|
||||||
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
|
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
|
||||||
|
|||||||
14
client/internal/stdnet/discover.go
Normal file
14
client/internal/stdnet/discover.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
import "github.com/pion/transport/v2"
|
||||||
|
|
||||||
|
// ExternalIFaceDiscover provide an option for external services (mobile)
|
||||||
|
// to collect network interface information
|
||||||
|
type ExternalIFaceDiscover interface {
|
||||||
|
// IFaces return with the description of the interfaces
|
||||||
|
IFaces() (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type iFaceDiscover interface {
|
||||||
|
iFaces() ([]*transport.Interface, error)
|
||||||
|
}
|
||||||
95
client/internal/stdnet/discover_mobile.go
Normal file
95
client/internal/stdnet/discover_mobile.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v2"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mobileIFaceDiscover struct {
|
||||||
|
externalDiscover ExternalIFaceDiscover
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMobileIFaceDiscover(externalDiscover ExternalIFaceDiscover) *mobileIFaceDiscover {
|
||||||
|
return &mobileIFaceDiscover{
|
||||||
|
externalDiscover: externalDiscover,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mobileIFaceDiscover) iFaces() ([]*transport.Interface, error) {
|
||||||
|
ifacesString, err := m.externalDiscover.IFaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
interfaces := m.parseInterfacesString(ifacesString)
|
||||||
|
return interfaces, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mobileIFaceDiscover) parseInterfacesString(interfaces string) []*transport.Interface {
|
||||||
|
ifs := []*transport.Interface{}
|
||||||
|
|
||||||
|
for _, iface := range strings.Split(interfaces, "\n") {
|
||||||
|
if strings.TrimSpace(iface) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Split(iface, "|")
|
||||||
|
if len(fields) != 2 {
|
||||||
|
log.Warnf("parseInterfacesString: unable to split %q", iface)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var name string
|
||||||
|
var index, mtu int
|
||||||
|
var up, broadcast, loopback, pointToPoint, multicast bool
|
||||||
|
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
|
||||||
|
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newIf := net.Interface{
|
||||||
|
Name: name,
|
||||||
|
Index: index,
|
||||||
|
MTU: mtu,
|
||||||
|
}
|
||||||
|
if up {
|
||||||
|
newIf.Flags |= net.FlagUp
|
||||||
|
}
|
||||||
|
if broadcast {
|
||||||
|
newIf.Flags |= net.FlagBroadcast
|
||||||
|
}
|
||||||
|
if loopback {
|
||||||
|
newIf.Flags |= net.FlagLoopback
|
||||||
|
}
|
||||||
|
if pointToPoint {
|
||||||
|
newIf.Flags |= net.FlagPointToPoint
|
||||||
|
}
|
||||||
|
if multicast {
|
||||||
|
newIf.Flags |= net.FlagMulticast
|
||||||
|
}
|
||||||
|
|
||||||
|
ifc := transport.NewInterface(newIf)
|
||||||
|
|
||||||
|
addrs := strings.Trim(fields[1], " \n")
|
||||||
|
foundAddress := false
|
||||||
|
for _, addr := range strings.Split(addrs, " ") {
|
||||||
|
ip, ipNet, err := net.ParseCIDR(addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("%s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ipNet.IP = ip
|
||||||
|
ifc.AddAddress(ipNet)
|
||||||
|
foundAddress = true
|
||||||
|
}
|
||||||
|
if foundAddress {
|
||||||
|
ifs = append(ifs, ifc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ifs
|
||||||
|
}
|
||||||
68
client/internal/stdnet/discover_mobile_test.go
Normal file
68
client/internal/stdnet/discover_mobile_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseInterfacesString(t *testing.T) {
|
||||||
|
testData := []struct {
|
||||||
|
name string
|
||||||
|
index int
|
||||||
|
mtu int
|
||||||
|
up bool
|
||||||
|
broadcast bool
|
||||||
|
loopBack bool
|
||||||
|
pointToPoint bool
|
||||||
|
multicast bool
|
||||||
|
addr string
|
||||||
|
}{
|
||||||
|
{"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"},
|
||||||
|
{"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"},
|
||||||
|
{"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var exampleString string
|
||||||
|
for _, d := range testData {
|
||||||
|
exampleString = fmt.Sprintf("%s\n%s %d %d %t %t %t %t %t | %s", exampleString,
|
||||||
|
d.name,
|
||||||
|
d.index,
|
||||||
|
d.mtu,
|
||||||
|
d.up,
|
||||||
|
d.broadcast,
|
||||||
|
d.loopBack,
|
||||||
|
d.pointToPoint,
|
||||||
|
d.multicast,
|
||||||
|
d.addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
d := mobileIFaceDiscover{}
|
||||||
|
nets := d.parseInterfacesString(exampleString)
|
||||||
|
if len(nets) == 0 {
|
||||||
|
t.Fatalf("failed to parse interfaces")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, net := range nets {
|
||||||
|
if net.MTU != testData[i].mtu {
|
||||||
|
t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if net.Interface.Name != testData[i].name {
|
||||||
|
t.Errorf("invalid interface name: %s, expected: %s", net.Interface.Name, testData[i].name)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := net.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addr) == 0 {
|
||||||
|
t.Errorf("invalid address parsing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr[0].String() != testData[i].addr {
|
||||||
|
t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
36
client/internal/stdnet/discover_pion.go
Normal file
36
client/internal/stdnet/discover_pion.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pionDiscover struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d pionDiscover) iFaces() ([]*transport.Interface, error) {
|
||||||
|
ifs := []*transport.Interface{}
|
||||||
|
|
||||||
|
oifs, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, oif := range oifs {
|
||||||
|
ifc := transport.NewInterface(oif)
|
||||||
|
|
||||||
|
addrs, err := oif.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ifc.AddAddress(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
ifs = append(ifs, ifc)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ifs, nil
|
||||||
|
}
|
||||||
40
client/internal/stdnet/filter.go
Normal file
40
client/internal/stdnet/filter.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InterfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
|
||||||
|
// to avoid building tunnel over them.
|
||||||
|
func InterfaceFilter(disallowList []string) func(string) bool {
|
||||||
|
|
||||||
|
return func(iFace string) bool {
|
||||||
|
|
||||||
|
if strings.HasPrefix(iFace, "lo") {
|
||||||
|
// hardcoded loopback check to support already installed agents
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range disallowList {
|
||||||
|
if strings.HasPrefix(iFace, s) {
|
||||||
|
log.Debugf("ignoring interface %s - it is not allowed", iFace)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// look for unlisted WireGuard interfaces
|
||||||
|
wg, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("trying to create a wgctrl client failed with: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = wg.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = wg.Device(iFace)
|
||||||
|
return err != nil
|
||||||
|
}
|
||||||
|
}
|
||||||
97
client/internal/stdnet/stdnet.go
Normal file
97
client/internal/stdnet/stdnet.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
// Package stdnet is an extension of the pion's stdnet.
|
||||||
|
// With it the list of the interface can come from external source.
|
||||||
|
// More info: https://github.com/golang/go/issues/40569
|
||||||
|
package stdnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v2"
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Net is an implementation of the net.Net interface
|
||||||
|
// based on functions of the standard net package.
|
||||||
|
type Net struct {
|
||||||
|
stdnet.Net
|
||||||
|
interfaces []*transport.Interface
|
||||||
|
iFaceDiscover iFaceDiscover
|
||||||
|
// interfaceFilter should return true if the given interfaceName is allowed
|
||||||
|
interfaceFilter func(interfaceName string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNetWithDiscover creates a new StdNet instance.
|
||||||
|
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||||
|
n := &Net{
|
||||||
|
iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover),
|
||||||
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
|
}
|
||||||
|
return n, n.UpdateInterfaces()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNet creates a new StdNet instance.
|
||||||
|
func NewNet(disallowList []string) (*Net, error) {
|
||||||
|
n := &Net{
|
||||||
|
iFaceDiscover: pionDiscover{},
|
||||||
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
|
}
|
||||||
|
return n, n.UpdateInterfaces()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateInterfaces updates the internal list of network interfaces
|
||||||
|
// and associated addresses filtering them by name.
|
||||||
|
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
||||||
|
// wasn't specified.
|
||||||
|
func (n *Net) UpdateInterfaces() (err error) {
|
||||||
|
allIfaces, err := n.iFaceDiscover.iFaces()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.interfaces = n.filterInterfaces(allIfaces)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interfaces returns a slice of interfaces which are available on the
|
||||||
|
// system
|
||||||
|
func (n *Net) Interfaces() ([]*transport.Interface, error) {
|
||||||
|
return n.interfaces, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InterfaceByIndex returns the interface specified by index.
|
||||||
|
//
|
||||||
|
// On Solaris, it returns one of the logical network interfaces
|
||||||
|
// sharing the logical data link; for more precision use
|
||||||
|
// InterfaceByName.
|
||||||
|
func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
|
||||||
|
for _, ifc := range n.interfaces {
|
||||||
|
if ifc.Index == index {
|
||||||
|
return ifc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InterfaceByName returns the interface specified by name.
|
||||||
|
func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
|
||||||
|
for _, ifc := range n.interfaces {
|
||||||
|
if ifc.Name == name {
|
||||||
|
return ifc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.Interface {
|
||||||
|
if n.interfaceFilter == nil {
|
||||||
|
return interfaces
|
||||||
|
}
|
||||||
|
result := []*transport.Interface{}
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
if n.interfaceFilter(iface.Name) {
|
||||||
|
result = append(result, iface)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -6,4 +6,4 @@
|
|||||||
#define EXPAND(x) STRINGIZE(x)
|
#define EXPAND(x) STRINGIZE(x)
|
||||||
CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
|
CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
|
||||||
7 ICON ui/netbird.ico
|
7 ICON ui/netbird.ico
|
||||||
wireguard.dll RCDATA wireguard.dll
|
wintun.dll RCDATA wintun.dll
|
||||||
|
|||||||
@@ -3,20 +3,19 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server for service control.
|
// Server for service control.
|
||||||
@@ -34,7 +33,7 @@ type Server struct {
|
|||||||
config *internal.Config
|
config *internal.Config
|
||||||
proto.UnimplementedDaemonServiceServer
|
proto.UnimplementedDaemonServiceServer
|
||||||
|
|
||||||
statusRecorder *nbStatus.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauthAuthFlow struct {
|
type oauthAuthFlow struct {
|
||||||
@@ -77,9 +76,9 @@ func (s *Server) Start() error {
|
|||||||
|
|
||||||
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
|
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
|
||||||
// on failure we return error to retry
|
// on failure we return error to retry
|
||||||
config, err := internal.ReadConfig(s.latestConfigInput)
|
config, err := internal.UpdateConfig(s.latestConfigInput)
|
||||||
if errorStatus, ok := gstatus.FromError(err); ok && errorStatus.Code() == codes.NotFound {
|
if errorStatus, ok := gstatus.FromError(err); ok && errorStatus.Code() == codes.NotFound {
|
||||||
config, err = internal.GetConfig(s.latestConfigInput)
|
s.config, err = internal.UpdateOrCreateConfig(s.latestConfigInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("unable to create configuration file: %v", err)
|
log.Warnf("unable to create configuration file: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -97,11 +96,13 @@ func (s *Server) Start() error {
|
|||||||
s.config = config
|
s.config = config
|
||||||
|
|
||||||
if s.statusRecorder == nil {
|
if s.statusRecorder == nil {
|
||||||
s.statusRecorder = nbStatus.NewRecorder()
|
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||||
|
} else {
|
||||||
|
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
|
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil); err != nil {
|
||||||
log.Errorf("init connections: %v", err)
|
log.Errorf("init connections: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -182,7 +183,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
|
|
||||||
inputConfig.PreSharedKey = &msg.PreSharedKey
|
inputConfig.PreSharedKey = &msg.PreSharedKey
|
||||||
|
|
||||||
config, err := internal.GetConfig(inputConfig)
|
config, err := internal.UpdateOrCreateConfig(inputConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -205,7 +206,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
state.Set(internal.StatusConnecting)
|
state.Set(internal.StatusConnecting)
|
||||||
|
|
||||||
if msg.SetupKey == "" {
|
if msg.SetupKey == "" {
|
||||||
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config)
|
providerConfig, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.Set(internal.StatusLoginFailed)
|
state.Set(internal.StatusLoginFailed)
|
||||||
s, ok := gstatus.FromError(err)
|
s, ok := gstatus.FromError(err)
|
||||||
@@ -222,12 +223,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostedClient := internal.NewHostedDeviceFlow(
|
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||||
providerConfig.ProviderConfig.Audience,
|
|
||||||
providerConfig.ProviderConfig.ClientID,
|
|
||||||
providerConfig.ProviderConfig.TokenEndpoint,
|
|
||||||
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
|
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
|
||||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||||
@@ -343,7 +339,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
|||||||
s.oauthAuthFlow.expiresAt = time.Now()
|
s.oauthAuthFlow.expiresAt = time.Now()
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.AccessToken); err != nil {
|
if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.GetTokenToUse()); err != nil {
|
||||||
state.Set(loginStatus)
|
state.Set(loginStatus)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -387,11 +383,13 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.statusRecorder == nil {
|
if s.statusRecorder == nil {
|
||||||
s.statusRecorder = nbStatus.NewRecorder()
|
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
|
||||||
|
} else {
|
||||||
|
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
|
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil); err != nil {
|
||||||
log.Errorf("run client connection: %v", err)
|
log.Errorf("run client connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -428,10 +426,12 @@ func (s *Server) Status(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: system.NetbirdVersion()}
|
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
|
||||||
|
|
||||||
if s.statusRecorder == nil {
|
if s.statusRecorder == nil {
|
||||||
s.statusRecorder = nbStatus.NewRecorder()
|
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
|
||||||
|
} else {
|
||||||
|
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if msg.GetFullPeerStatus {
|
if msg.GetFullPeerStatus {
|
||||||
@@ -477,7 +477,7 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtoFullStatus(fullStatus nbStatus.FullStatus) *proto.FullStatus {
|
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||||
pbFullStatus := proto.FullStatus{
|
pbFullStatus := proto.FullStatus{
|
||||||
ManagementState: &proto.ManagementState{},
|
ManagementState: &proto.ManagementState{},
|
||||||
SignalState: &proto.SignalState{},
|
SignalState: &proto.SignalState{},
|
||||||
@@ -500,7 +500,7 @@ func toProtoFullStatus(fullStatus nbStatus.FullStatus) *proto.FullStatus {
|
|||||||
pbPeerState := &proto.PeerState{
|
pbPeerState := &proto.PeerState{
|
||||||
IP: peerState.IP,
|
IP: peerState.IP,
|
||||||
PubKey: peerState.PubKey,
|
PubKey: peerState.PubKey,
|
||||||
ConnStatus: peerState.ConnStatus,
|
ConnStatus: peerState.ConnStatus.String(),
|
||||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||||
Relayed: peerState.Relayed,
|
Relayed: peerState.Relayed,
|
||||||
Direct: peerState.Direct,
|
Direct: peerState.Direct,
|
||||||
|
|||||||
@@ -1,241 +0,0 @@
|
|||||||
package status
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
|
||||||
type PeerState struct {
|
|
||||||
IP string
|
|
||||||
PubKey string
|
|
||||||
FQDN string
|
|
||||||
ConnStatus string
|
|
||||||
ConnStatusUpdate time.Time
|
|
||||||
Relayed bool
|
|
||||||
Direct bool
|
|
||||||
LocalIceCandidateType string
|
|
||||||
RemoteIceCandidateType string
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalPeerState contains the latest state of the local peer
|
|
||||||
type LocalPeerState struct {
|
|
||||||
IP string
|
|
||||||
PubKey string
|
|
||||||
KernelInterface bool
|
|
||||||
FQDN string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignalState contains the latest state of a signal connection
|
|
||||||
type SignalState struct {
|
|
||||||
URL string
|
|
||||||
Connected bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// ManagementState contains the latest state of a management connection
|
|
||||||
type ManagementState struct {
|
|
||||||
URL string
|
|
||||||
Connected bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
|
||||||
type FullStatus struct {
|
|
||||||
Peers []PeerState
|
|
||||||
ManagementState ManagementState
|
|
||||||
SignalState SignalState
|
|
||||||
LocalPeerState LocalPeerState
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status holds a state of peers, signal and management connections
|
|
||||||
type Status struct {
|
|
||||||
mux sync.Mutex
|
|
||||||
peers map[string]PeerState
|
|
||||||
changeNotify map[string]chan struct{}
|
|
||||||
signal SignalState
|
|
||||||
management ManagementState
|
|
||||||
localPeer LocalPeerState
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRecorder returns a new Status instance
|
|
||||||
func NewRecorder() *Status {
|
|
||||||
return &Status{
|
|
||||||
peers: make(map[string]PeerState),
|
|
||||||
changeNotify: make(map[string]chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPeer adds peer to Daemon status map
|
|
||||||
func (d *Status) AddPeer(peerPubKey string) error {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
_, ok := d.peers[peerPubKey]
|
|
||||||
if ok {
|
|
||||||
return errors.New("peer already exist")
|
|
||||||
}
|
|
||||||
d.peers[peerPubKey] = PeerState{PubKey: peerPubKey}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeer adds peer to Daemon status map
|
|
||||||
func (d *Status) GetPeer(peerPubKey string) (PeerState, error) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
state, ok := d.peers[peerPubKey]
|
|
||||||
if !ok {
|
|
||||||
return PeerState{}, errors.New("peer not found")
|
|
||||||
}
|
|
||||||
return state, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePeer removes peer from Daemon status map
|
|
||||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
_, ok := d.peers[peerPubKey]
|
|
||||||
if ok {
|
|
||||||
delete(d.peers, peerPubKey)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.New("no peer with to remove")
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeerState updates peer status
|
|
||||||
func (d *Status) UpdatePeerState(receivedState PeerState) error {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
peerState, ok := d.peers[receivedState.PubKey]
|
|
||||||
if !ok {
|
|
||||||
return errors.New("peer doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
if receivedState.IP != "" {
|
|
||||||
peerState.IP = receivedState.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
|
||||||
peerState.ConnStatus = receivedState.ConnStatus
|
|
||||||
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
|
||||||
peerState.Direct = receivedState.Direct
|
|
||||||
peerState.Relayed = receivedState.Relayed
|
|
||||||
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
|
|
||||||
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
|
|
||||||
}
|
|
||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeerFQDN update peer's state fqdn only
|
|
||||||
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
peerState, ok := d.peers[peerPubKey]
|
|
||||||
if !ok {
|
|
||||||
return errors.New("peer doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState.FQDN = fqdn
|
|
||||||
d.peers[peerPubKey] = peerState
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
|
||||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
ch, found := d.changeNotify[peer]
|
|
||||||
if !found || ch == nil {
|
|
||||||
ch = make(chan struct{})
|
|
||||||
d.changeNotify[peer] = ch
|
|
||||||
}
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateLocalPeerState updates local peer status
|
|
||||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
d.localPeer = localPeerState
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanLocalPeerState cleans local peer status
|
|
||||||
func (d *Status) CleanLocalPeerState() {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
d.localPeer = LocalPeerState{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
|
||||||
func (d *Status) MarkManagementDisconnected(managementURL string) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
d.management = ManagementState{
|
|
||||||
URL: managementURL,
|
|
||||||
Connected: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkManagementConnected sets ManagementState to connected
|
|
||||||
func (d *Status) MarkManagementConnected(managementURL string) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
d.management = ManagementState{
|
|
||||||
URL: managementURL,
|
|
||||||
Connected: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkSignalDisconnected sets SignalState to disconnected
|
|
||||||
func (d *Status) MarkSignalDisconnected(signalURL string) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
d.signal = SignalState{
|
|
||||||
signalURL,
|
|
||||||
false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkSignalConnected sets SignalState to connected
|
|
||||||
func (d *Status) MarkSignalConnected(signalURL string) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
d.signal = SignalState{
|
|
||||||
signalURL,
|
|
||||||
true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFullStatus gets full status
|
|
||||||
func (d *Status) GetFullStatus() FullStatus {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
fullStatus := FullStatus{
|
|
||||||
ManagementState: d.management,
|
|
||||||
SignalState: d.signal,
|
|
||||||
LocalPeerState: d.localPeer,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, status := range d.peers {
|
|
||||||
fullStatus.Peers = append(fullStatus.Peers, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fullStatus
|
|
||||||
}
|
|
||||||
@@ -1,243 +0,0 @@
|
|||||||
package status
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAddPeer(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
status := NewRecorder()
|
|
||||||
err := status.AddPeer(key)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
_, exists := status.peers[key]
|
|
||||||
assert.True(t, exists, "value was found")
|
|
||||||
|
|
||||||
err = status.AddPeer(key)
|
|
||||||
|
|
||||||
assert.Error(t, err, "should return error on duplicate")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeer(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
status := NewRecorder()
|
|
||||||
err := status.AddPeer(key)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
peerStatus, err := status.GetPeer(key)
|
|
||||||
assert.NoError(t, err, "shouldn't return error on getting peer")
|
|
||||||
|
|
||||||
assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match")
|
|
||||||
|
|
||||||
_, err = status.GetPeer("non_existing_key")
|
|
||||||
assert.Error(t, err, "should return error when peer doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdatePeerState(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
ip := "10.10.10.10"
|
|
||||||
status := NewRecorder()
|
|
||||||
peerState := PeerState{
|
|
||||||
PubKey: key,
|
|
||||||
}
|
|
||||||
|
|
||||||
status.peers[key] = peerState
|
|
||||||
|
|
||||||
peerState.IP = ip
|
|
||||||
|
|
||||||
err := status.UpdatePeerState(peerState)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
state, exists := status.peers[key]
|
|
||||||
assert.True(t, exists, "state should be found")
|
|
||||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
fqdn := "peer-a.netbird.local"
|
|
||||||
status := NewRecorder()
|
|
||||||
peerState := PeerState{
|
|
||||||
PubKey: key,
|
|
||||||
}
|
|
||||||
|
|
||||||
status.peers[key] = peerState
|
|
||||||
|
|
||||||
err := status.UpdatePeerFQDN(key, fqdn)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
state, exists := status.peers[key]
|
|
||||||
assert.True(t, exists, "state should be found")
|
|
||||||
assert.Equal(t, fqdn, state.FQDN, "fqdn should be equal")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
ip := "10.10.10.10"
|
|
||||||
status := NewRecorder()
|
|
||||||
peerState := PeerState{
|
|
||||||
PubKey: key,
|
|
||||||
}
|
|
||||||
|
|
||||||
status.peers[key] = peerState
|
|
||||||
|
|
||||||
ch := status.GetPeerStateChangeNotifier(key)
|
|
||||||
assert.NotNil(t, ch, "channel shouldn't be nil")
|
|
||||||
|
|
||||||
peerState.IP = ip
|
|
||||||
|
|
||||||
err := status.UpdatePeerState(peerState)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ch:
|
|
||||||
default:
|
|
||||||
t.Errorf("channel wasn't closed after update")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRemovePeer(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
status := NewRecorder()
|
|
||||||
peerState := PeerState{
|
|
||||||
PubKey: key,
|
|
||||||
}
|
|
||||||
|
|
||||||
status.peers[key] = peerState
|
|
||||||
|
|
||||||
err := status.RemovePeer(key)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
_, exists := status.peers[key]
|
|
||||||
assert.False(t, exists, "state value shouldn't be found")
|
|
||||||
|
|
||||||
err = status.RemovePeer("not existing")
|
|
||||||
assert.Error(t, err, "should return error when peer doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateLocalPeerState(t *testing.T) {
|
|
||||||
localPeerState := LocalPeerState{
|
|
||||||
IP: "10.10.10.10",
|
|
||||||
PubKey: "abc",
|
|
||||||
KernelInterface: false,
|
|
||||||
}
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
status.UpdateLocalPeerState(localPeerState)
|
|
||||||
|
|
||||||
assert.Equal(t, localPeerState, status.localPeer, "local peer status should be equal")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCleanLocalPeerState(t *testing.T) {
|
|
||||||
emptyLocalPeerState := LocalPeerState{}
|
|
||||||
localPeerState := LocalPeerState{
|
|
||||||
IP: "10.10.10.10",
|
|
||||||
PubKey: "abc",
|
|
||||||
KernelInterface: false,
|
|
||||||
}
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
status.localPeer = localPeerState
|
|
||||||
|
|
||||||
status.CleanLocalPeerState()
|
|
||||||
|
|
||||||
assert.Equal(t, emptyLocalPeerState, status.localPeer, "local peer status should be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateSignalState(t *testing.T) {
|
|
||||||
url := "https://signal"
|
|
||||||
var tests = []struct {
|
|
||||||
name string
|
|
||||||
connected bool
|
|
||||||
want SignalState
|
|
||||||
}{
|
|
||||||
{"should mark as connected", true, SignalState{
|
|
||||||
|
|
||||||
URL: url,
|
|
||||||
Connected: true,
|
|
||||||
}},
|
|
||||||
{"should mark as disconnected", false, SignalState{
|
|
||||||
URL: url,
|
|
||||||
Connected: false,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
if test.connected {
|
|
||||||
status.MarkSignalConnected(url)
|
|
||||||
} else {
|
|
||||||
status.MarkSignalDisconnected(url)
|
|
||||||
}
|
|
||||||
assert.Equal(t, test.want, status.signal, "signal status should be equal")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateManagementState(t *testing.T) {
|
|
||||||
url := "https://management"
|
|
||||||
var tests = []struct {
|
|
||||||
name string
|
|
||||||
connected bool
|
|
||||||
want ManagementState
|
|
||||||
}{
|
|
||||||
{"should mark as connected", true, ManagementState{
|
|
||||||
|
|
||||||
URL: url,
|
|
||||||
Connected: true,
|
|
||||||
}},
|
|
||||||
{"should mark as disconnected", false, ManagementState{
|
|
||||||
URL: url,
|
|
||||||
Connected: false,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
if test.connected {
|
|
||||||
status.MarkManagementConnected(url)
|
|
||||||
} else {
|
|
||||||
status.MarkManagementDisconnected(url)
|
|
||||||
}
|
|
||||||
assert.Equal(t, test.want, status.management, "signal status should be equal")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetFullStatus(t *testing.T) {
|
|
||||||
key1 := "abc"
|
|
||||||
key2 := "def"
|
|
||||||
managementState := ManagementState{
|
|
||||||
URL: "https://signal",
|
|
||||||
Connected: true,
|
|
||||||
}
|
|
||||||
signalState := SignalState{
|
|
||||||
URL: "https://signal",
|
|
||||||
Connected: true,
|
|
||||||
}
|
|
||||||
peerState1 := PeerState{
|
|
||||||
PubKey: key1,
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState2 := PeerState{
|
|
||||||
PubKey: key2,
|
|
||||||
}
|
|
||||||
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
status.management = managementState
|
|
||||||
status.signal = signalState
|
|
||||||
status.peers[key1] = peerState1
|
|
||||||
status.peers[key2] = peerState2
|
|
||||||
|
|
||||||
fullStatus := status.GetFullStatus()
|
|
||||||
|
|
||||||
assert.Equal(t, managementState, fullStatus.ManagementState, "management status should be equal")
|
|
||||||
assert.Equal(t, signalState, fullStatus.SignalState, "signal status should be equal")
|
|
||||||
assert.ElementsMatch(t, []PeerState{peerState1, peerState2}, fullStatus.Peers, "peers states should match")
|
|
||||||
}
|
|
||||||
@@ -2,15 +2,17 @@ package system
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// this is the wiretrustee version
|
// DeviceNameCtxKey context key for device name
|
||||||
// will be replaced with the release version when using goreleaser
|
const DeviceNameCtxKey = "deviceName"
|
||||||
var version = "development"
|
|
||||||
|
|
||||||
//Info is an object that contains machine information
|
// Info is an object that contains machine information
|
||||||
// Most of the code is taken from https://github.com/matishsiao/goInfo
|
// Most of the code is taken from https://github.com/matishsiao/goInfo
|
||||||
type Info struct {
|
type Info struct {
|
||||||
GoOS string
|
GoOS string
|
||||||
@@ -25,11 +27,6 @@ type Info struct {
|
|||||||
UIVersion string
|
UIVersion string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NetbirdVersion returns the Netbird version
|
|
||||||
func NetbirdVersion() string {
|
|
||||||
return version
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
func extractUserAgent(ctx context.Context) string {
|
func extractUserAgent(ctx context.Context) string {
|
||||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||||
@@ -48,5 +45,5 @@ func extractUserAgent(ctx context.Context) string {
|
|||||||
|
|
||||||
// GetDesktopUIUserAgent returns the Desktop ui user agent
|
// GetDesktopUIUserAgent returns the Desktop ui user agent
|
||||||
func GetDesktopUIUserAgent() string {
|
func GetDesktopUIUserAgent() string {
|
||||||
return "netbird-desktop-ui/" + NetbirdVersion()
|
return "netbird-desktop-ui/" + version.NetbirdVersion()
|
||||||
}
|
}
|
||||||
|
|||||||
63
client/system/info_android.go
Normal file
63
client/system/info_android.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
//go:build android
|
||||||
|
// +build android
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetInfo retrieves and parses the system information
|
||||||
|
func GetInfo(ctx context.Context) *Info {
|
||||||
|
kernel := "android"
|
||||||
|
osInfo := uname()
|
||||||
|
if len(osInfo) == 2 {
|
||||||
|
kernel = osInfo[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||||
|
gio.Hostname = extractDeviceName(ctx)
|
||||||
|
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||||
|
gio.UIVersion = extractUserAgent(ctx)
|
||||||
|
|
||||||
|
return gio
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDeviceName(ctx context.Context) string {
|
||||||
|
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
||||||
|
if !ok {
|
||||||
|
return "android"
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func uname() []string {
|
||||||
|
res := run("/system/bin/uname", "-a")
|
||||||
|
return strings.Split(res, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func osVersion() string {
|
||||||
|
return run("/system/bin/getprop", "ro.build.version.release")
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(name string, arg ...string) string {
|
||||||
|
cmd := exec.Command(name, arg...)
|
||||||
|
cmd.Stdin = strings.NewReader("some")
|
||||||
|
var out bytes.Buffer
|
||||||
|
var stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &out
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
err := cmd.Run()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("getInfo: %s", err)
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
@@ -4,12 +4,16 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
@@ -22,14 +26,14 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
|
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
|
||||||
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
|
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
|
||||||
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
|
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
|
||||||
version, err := exec.Command("sw_vers", "-productVersion").Output()
|
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
|
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
|
||||||
version = []byte(release)
|
swVersion = []byte(release)
|
||||||
}
|
}
|
||||||
gio := &Info{Kernel: sysName, OSVersion: strings.TrimSpace(string(version)), Core: release, Platform: machine, OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
gio := &Info{Kernel: sysName, OSVersion: strings.TrimSpace(string(swVersion)), Core: release, Platform: machine, OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||||
gio.Hostname, _ = os.Hostname()
|
gio.Hostname, _ = os.Hostname()
|
||||||
gio.WiretrusteeVersion = NetbirdVersion()
|
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||||
gio.UIVersion = extractUserAgent(ctx)
|
gio.UIVersion = extractUserAgent(ctx)
|
||||||
|
|
||||||
return gio
|
return gio
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
@@ -23,7 +25,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
osInfo := strings.Split(osStr, " ")
|
osInfo := strings.Split(osStr, " ")
|
||||||
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||||
gio.Hostname, _ = os.Hostname()
|
gio.Hostname, _ = os.Hostname()
|
||||||
gio.WiretrusteeVersion = NetbirdVersion()
|
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||||
gio.UIVersion = extractUserAgent(ctx)
|
gio.UIVersion = extractUserAgent(ctx)
|
||||||
|
|
||||||
return gio
|
return gio
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
//go:build !android
|
||||||
|
// +build !android
|
||||||
|
|
||||||
package system
|
package system
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -9,6 +12,8 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
@@ -46,7 +51,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
}
|
}
|
||||||
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: osInfo[2], OS: osName, OSVersion: osVer, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: osInfo[2], OS: osName, OSVersion: osVer, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||||
gio.Hostname, _ = os.Hostname()
|
gio.Hostname, _ = os.Hostname()
|
||||||
gio.WiretrusteeVersion = NetbirdVersion()
|
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||||
gio.UIVersion = extractUserAgent(ctx)
|
gio.UIVersion = extractUserAgent(ctx)
|
||||||
|
|
||||||
return gio
|
return gio
|
||||||
|
|||||||
@@ -3,10 +3,13 @@ package system
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/windows/registry"
|
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
@@ -14,7 +17,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
ver := getOSVersion()
|
ver := getOSVersion()
|
||||||
gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
|
||||||
gio.Hostname, _ = os.Hostname()
|
gio.Hostname, _ = os.Hostname()
|
||||||
gio.WiretrusteeVersion = NetbirdVersion()
|
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||||
gio.UIVersion = extractUserAgent(ctx)
|
gio.UIVersion = extractUserAgent(ctx)
|
||||||
|
|
||||||
return gio
|
return gio
|
||||||
@@ -32,7 +35,7 @@ func getOSVersion() string {
|
|||||||
log.Error(deferErr)
|
log.Error(deferErr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
|
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
_ "embed"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -17,25 +18,22 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"fyne.io/fyne/v2"
|
||||||
|
"fyne.io/fyne/v2/app"
|
||||||
|
"fyne.io/fyne/v2/dialog"
|
||||||
|
"fyne.io/fyne/v2/widget"
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
|
||||||
_ "embed"
|
|
||||||
|
|
||||||
"github.com/getlantern/systray"
|
"github.com/getlantern/systray"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"fyne.io/fyne/v2"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"fyne.io/fyne/v2/app"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"fyne.io/fyne/v2/dialog"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"fyne.io/fyne/v2/widget"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -373,7 +371,7 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
systray.AddSeparator()
|
systray.AddSeparator()
|
||||||
s.mSettings = systray.AddMenuItem("Settings", "Settings of the application")
|
s.mSettings = systray.AddMenuItem("Settings", "Settings of the application")
|
||||||
systray.AddSeparator()
|
systray.AddSeparator()
|
||||||
v := systray.AddMenuItem("v"+system.NetbirdVersion(), "Client Version: "+system.NetbirdVersion())
|
v := systray.AddMenuItem("v"+version.NetbirdVersion(), "Client Version: "+version.NetbirdVersion())
|
||||||
v.Disable()
|
v.Disable()
|
||||||
systray.AddSeparator()
|
systray.AddSeparator()
|
||||||
s.mQuit = systray.AddMenuItem("Quit", "Quit the client app")
|
s.mQuit = systray.AddMenuItem("Quit", "Quit the client app")
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
ldir=$PWD
|
|
||||||
tmp_dir_path=$ldir/.distfiles
|
|
||||||
winnt=wireguard-nt.zip
|
|
||||||
download_file_path=$tmp_dir_path/$winnt
|
|
||||||
download_url=https://download.wireguard.com/wireguard-nt/wireguard-nt-0.10.1.zip
|
|
||||||
download_sha=772c0b1463d8d2212716f43f06f4594d880dea4f735165bd68e388fc41b81605
|
|
||||||
|
|
||||||
function resources_windows(){
|
|
||||||
cmd=$1
|
|
||||||
arch=$2
|
|
||||||
out=$3
|
|
||||||
docker run -i --rm -v $PWD:$PWD -w $PWD mstorsjo/llvm-mingw:latest $cmd -O coff -c 65001 -I $tmp_dir_path/wireguard-nt/bin/$arch -i resources.rc -o $out
|
|
||||||
}
|
|
||||||
|
|
||||||
mkdir -p $tmp_dir_path
|
|
||||||
curl -L#o $download_file_path.unverified $download_url
|
|
||||||
echo "$download_sha $download_file_path.unverified" | sha256sum -c
|
|
||||||
mv $download_file_path.unverified $download_file_path
|
|
||||||
|
|
||||||
mkdir -p .deps
|
|
||||||
unzip $download_file_path -d $tmp_dir_path
|
|
||||||
|
|
||||||
resources_windows i686-w64-mingw32-windres x86 resources_windows_386.syso
|
|
||||||
resources_windows aarch64-w64-mingw32-windres arm64 resources_windows_arm64.syso
|
|
||||||
resources_windows x86_64-w64-mingw32-windres amd64 resources_windows_amd64.syso
|
|
||||||
@@ -3,10 +3,13 @@ package encryption
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/crypto/nacl/box"
|
"golang.org/x/crypto/nacl/box"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const nonceSize = 24
|
||||||
|
|
||||||
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
|
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
|
||||||
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
||||||
// Wireguard keys are used for encryption
|
// Wireguard keys are used for encryption
|
||||||
@@ -26,8 +29,11 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
copy(nonce[:], encryptedMsg[:24])
|
if len(encryptedMsg) < nonceSize {
|
||||||
opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
|
return nil, fmt.Errorf("invalid encrypted message lenght")
|
||||||
|
}
|
||||||
|
copy(nonce[:], encryptedMsg[:nonceSize])
|
||||||
|
opened, ok := box.Open(nil, encryptedMsg[nonceSize:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String())
|
return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String())
|
||||||
}
|
}
|
||||||
@@ -36,8 +42,8 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generates nonce of size 24
|
// Generates nonce of size 24
|
||||||
func genNonce() (*[24]byte, error) {
|
func genNonce() (*[nonceSize]byte, error) {
|
||||||
var nonce [24]byte
|
var nonce [nonceSize]byte
|
||||||
if _, err := rand.Read(nonce[:]); err != nil {
|
if _, err := rand.Read(nonce[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
51
formatter/formatter.go
Normal file
51
formatter/formatter.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package formatter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TextFormatter formats logs into text with included source code's path
|
||||||
|
type TextFormatter struct {
|
||||||
|
timestampFormat string
|
||||||
|
levelDesc []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTextFormatter create new MyTextFormatter instance
|
||||||
|
func NewTextFormatter() *TextFormatter {
|
||||||
|
return &TextFormatter{
|
||||||
|
levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
|
||||||
|
timestampFormat: time.RFC3339, // or RFC3339
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format renders a single log entry
|
||||||
|
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||||
|
var fields string
|
||||||
|
keys := make([]string, 0, len(entry.Data))
|
||||||
|
for k, v := range entry.Data {
|
||||||
|
if k == "source" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) > 0 {
|
||||||
|
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
level := f.parseLevel(entry.Level)
|
||||||
|
|
||||||
|
return []byte(fmt.Sprintf("%s %s %s%s: %s\n", entry.Time.Format(f.timestampFormat), level, fields, entry.Data["source"], entry.Message)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *TextFormatter) parseLevel(level logrus.Level) string {
|
||||||
|
if len(f.levelDesc) < int(level) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.levelDesc[level]
|
||||||
|
}
|
||||||
26
formatter/formatter_test.go
Normal file
26
formatter/formatter_test.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package formatter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogMessageFormat(t *testing.T) {
|
||||||
|
|
||||||
|
someEntry := &logrus.Entry{
|
||||||
|
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
|
||||||
|
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
|
||||||
|
Level: 3,
|
||||||
|
Message: "Some Message",
|
||||||
|
}
|
||||||
|
|
||||||
|
formatter := NewTextFormatter()
|
||||||
|
result, _ := formatter.Format(someEntry)
|
||||||
|
|
||||||
|
parsedString := string(result)
|
||||||
|
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||||
|
assert.Regexp(t, expectedString, parsedString)
|
||||||
|
}
|
||||||
61
formatter/hook.go
Normal file
61
formatter/hook.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package formatter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path"
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContextHook is a custom hook for add the source information for the entry
|
||||||
|
type ContextHook struct {
|
||||||
|
goModuleName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContextHook instantiate a new context hook
|
||||||
|
func NewContextHook() *ContextHook {
|
||||||
|
hook := &ContextHook{}
|
||||||
|
hook.goModuleName = hook.moduleName() + "/"
|
||||||
|
return hook
|
||||||
|
}
|
||||||
|
|
||||||
|
// Levels set the supported levels for this hook
|
||||||
|
func (hook ContextHook) Levels() []logrus.Level {
|
||||||
|
return logrus.AllLevels
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fire extend with the source information the entry.Data
|
||||||
|
func (hook ContextHook) Fire(entry *logrus.Entry) error {
|
||||||
|
src := hook.parseSrc(entry.Caller.File)
|
||||||
|
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook ContextHook) moduleName() string {
|
||||||
|
info, ok := debug.ReadBuildInfo()
|
||||||
|
if ok && info.Main.Path != "" {
|
||||||
|
return info.Main.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
return "netbird"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hook ContextHook) parseSrc(filePath string) string {
|
||||||
|
netbirdPath := strings.SplitAfter(filePath, hook.goModuleName)
|
||||||
|
if len(netbirdPath) > 1 {
|
||||||
|
return netbirdPath[len(netbirdPath)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// in case of forked repo
|
||||||
|
netbirdPath = strings.SplitAfter(filePath, "netbird/")
|
||||||
|
if len(netbirdPath) > 1 {
|
||||||
|
return netbirdPath[len(netbirdPath)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// in case if log entry is come from external pkg
|
||||||
|
_, pkg := path.Split(path.Dir(filePath))
|
||||||
|
file := path.Base(filePath)
|
||||||
|
return fmt.Sprintf("%s/%s", pkg, file)
|
||||||
|
}
|
||||||
39
formatter/hook_test.go
Normal file
39
formatter/hook_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package formatter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilePathParsing(t *testing.T) {
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
filePath string
|
||||||
|
expectedFileName string
|
||||||
|
}{
|
||||||
|
// locally cloned repo
|
||||||
|
{
|
||||||
|
filePath: "/Users/user/Github/Netbird/netbird/formatter/formatter.go",
|
||||||
|
expectedFileName: "formatter/formatter.go",
|
||||||
|
},
|
||||||
|
// locally cloned repo with duplicated name in path
|
||||||
|
{
|
||||||
|
filePath: "/Users/user/netbird/repos/netbird/formatter/formatter.go",
|
||||||
|
expectedFileName: "formatter/formatter.go",
|
||||||
|
},
|
||||||
|
// locally cloned repo with renamed package root
|
||||||
|
{
|
||||||
|
filePath: "/Users/user/Github/MyOwnNetbirdClient/formatter/formatter.go",
|
||||||
|
expectedFileName: "formatter/formatter.go",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
hook := NewContextHook()
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
parsedString := hook.parseSrc(testCase.filePath)
|
||||||
|
assert.Equal(t, testCase.expectedFileName, parsedString, "Parsed filepath does not match expected for %s", testCase.filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user