Compare commits

...

158 Commits

Author SHA1 Message Date
Zoltan Papp
ca10541f50 Proposal for code cleaning 2023-04-11 17:06:55 +02:00
braginini
9b9f5fb64b Merge branch 'main' into feature/bind 2023-04-11 16:06:00 +02:00
Maycon Santos
3d3de0f1dd remove wintun.dll and use new sign-pipelines version 2023-04-11 15:27:45 +02:00
Misha Bragin
251f2d7bc2 Pass newly generated ID to network map when adding peer (#800) 2023-04-11 14:28:22 +02:00
braginini
94c6d8c55b Reduce log level for kernel module check 2023-04-11 11:21:40 +02:00
Maycon Santos
306e02d32b Update calculate server state (#796)
Refactored updateServerStates and calculateState

added some checks to ensure we are not sending connecting on context canceled

removed some state updates from the RunClient function
2023-04-10 18:22:25 +02:00
pascal-fischer
8375491708 Merge pull request #778 from netbirdio/fix/consistent_time_format_for_pat
fix/use_utc_for_time_operations
2023-04-10 18:11:41 +02:00
braginini
0f449f64ec Fix NPE in windows tunnel 2023-04-10 11:34:30 +02:00
braginini
a7953cef58 Correct bind error handling 2023-04-10 11:16:01 +02:00
braginini
2bf846b183 Close UAPI in Windows tun 2023-04-10 11:13:40 +02:00
Pascal Fischer
e197b89ac3 remove UTC from some not store related operations 2023-04-10 11:09:27 +02:00
Pascal Fischer
6aba28ccb7 remove UTC from some not store related operations 2023-04-10 10:54:23 +02:00
braginini
2a04175238 Fix windows tun GUID 2023-04-10 10:44:31 +02:00
Maycon Santos
84467c1531 test all and remove files 2023-04-09 01:11:48 +02:00
Maycon Santos
44b7afc06b use cmd 2023-04-09 00:58:15 +02:00
Maycon Santos
dea45633a6 store the output 2023-04-09 00:45:05 +02:00
Maycon Santos
14939202f0 add timeout and p 1 2023-04-09 00:17:04 +02:00
Maycon Santos
c5a8cc59a8 use id 2023-04-08 23:57:50 +02:00
Maycon Santos
3d3400ab43 use windwos latest 2023-04-08 23:54:14 +02:00
Maycon Santos
cdf5368d20 run with psexec 2023-04-08 23:52:12 +02:00
Maycon Santos
0803f16a49 debug update powershell.config.json 2023-04-08 22:20:21 +02:00
Maycon Santos
b6aafd8f09 debug update powershell.config.json 2023-04-08 22:19:11 +02:00
Maycon Santos
4c9fc9850c debug update powershell.config.json 2023-04-08 22:15:02 +02:00
Maycon Santos
0e9d5807d6 debug update powershell.config.json 2023-04-08 22:13:50 +02:00
Maycon Santos
b5a3248f9f debug remove powershell.config.json 2023-04-08 22:06:42 +02:00
Maycon Santos
5e1dfb28c0 debug 2023-04-08 21:25:33 +02:00
Maycon Santos
d1fe03a2d4 debug 2023-04-08 21:23:00 +02:00
Maycon Santos
bb453d334b remove DisableWinCompat file 2023-04-08 21:10:30 +02:00
Maycon Santos
56737bab6c remove DisableWinCompat file 2023-04-08 21:08:09 +02:00
Maycon Santos
37d20671a9 remove step 2023-04-08 20:58:39 +02:00
Maycon Santos
3f717eb759 Merge remote-tracking branch 'origin/feature/bind' into feature/bind 2023-04-08 20:51:51 +02:00
Maycon Santos
f9f8cbdcaa test windows 2019 2023-04-08 20:51:38 +02:00
braginini
a3849b978b Fix shoulduseproxy tests 2023-04-08 17:39:38 +02:00
braginini
1c071e4981 Revert "Only consider /16 network when using direct mode"
This reverts commit 189321f09d.
2023-04-08 17:38:13 +02:00
braginini
189321f09d Only consider /16 network when using direct mode 2023-04-08 17:37:48 +02:00
braginini
db5f931373 use direct mode only when in the same private /16 network 2023-04-08 17:22:53 +02:00
Maycon Santos
b556736b31 build cache 2023-04-08 13:29:47 +02:00
Maycon Santos
01e19a7c67 set unrestricted policy 2023-04-08 13:12:13 +02:00
Maycon Santos
9494cbdf24 import PSScheduledJob 2023-04-08 13:04:16 +02:00
Maycon Santos
9c09b13a25 import module in the same step 2023-04-08 12:56:23 +02:00
Maycon Santos
8bb999cf2a using https://github.com/mkellerman/Invoke-CommandAs 2023-04-08 12:53:03 +02:00
Maycon Santos
565b8ce1c7 use paexec and use always to build cache 2023-04-08 11:25:57 +02:00
Maycon Santos
e96a975737 extension 2023-04-08 11:14:48 +02:00
Maycon Santos
8585e3ccf7 test moving the bin to temp folder 2023-04-08 11:09:21 +02:00
Maycon Santos
cd002c6400 building a test bin 2023-04-08 11:03:31 +02:00
Maycon Santos
0629697db1 use sysinternals choco package 2023-04-08 10:59:29 +02:00
Maycon Santos
25a92a0052 check other command 2023-04-08 00:48:55 +02:00
Maycon Santos
9c51d85cb4 copy acl 2023-04-07 19:25:21 +02:00
Maycon Santos
9865179207 get acl 2023-04-07 19:17:42 +02:00
Maycon Santos
56f10085f4 dir out 2023-04-07 19:06:28 +02:00
Maycon Santos
a7574907ae use powershell command 2023-04-07 19:04:25 +02:00
Maycon Santos
71e81533bc list workspace dir 2023-04-07 18:45:52 +02:00
Maycon Santos
23b92e2615 add concurrency setting to workflows
it helps minimize github action views and stop outdated jobs
2023-04-07 18:41:05 +02:00
Maycon Santos
9158a4653a use workspace path 2023-04-07 18:28:22 +02:00
Maycon Santos
ccbf749171 psexec accept eula 2023-04-07 18:24:30 +02:00
Maycon Santos
dea7e8d4e7 download wintun and psexec
use rsrc tool to generate syso files
2023-04-07 18:20:08 +02:00
braginini
a0441e7d04 Commit unused code 2023-04-07 16:58:26 +02:00
braginini
9702946474 move dll to 2023-04-07 16:42:01 +02:00
braginini
e262f3536e move dll to 2023-04-07 16:34:26 +02:00
braginini
addfed3af0 upload artifact to root folder 2023-04-07 16:21:50 +02:00
braginini
bf723ec66f remove lazy dll load 2023-04-07 16:13:50 +02:00
braginini
10afc8eeb8 Try workflow with DLL 2023-04-07 16:04:02 +02:00
braginini
0b21e05a52 Revert "disable workflows"
This reverts commit 94c646f1e5.
2023-04-07 15:39:01 +02:00
braginini
94c646f1e5 disable workflows 2023-04-07 15:37:56 +02:00
braginini
4f7d34c5c7 try fixing ci/cd 2023-04-07 15:37:03 +02:00
braginini
0455e574b8 Add license comment for the copied dll code 2023-04-07 15:15:34 +02:00
braginini
965ba8837f Fix tun address assignment on windows 2023-04-07 15:11:21 +02:00
braginini
61146a51d0 Load wintun.dll 2023-04-07 14:21:54 +02:00
Maycon Santos
8f9826b207 Fix export path for certificate files (#794)
assign the value for NETBIRD_LETSENCRYPT_DOMAIN
in the base.setup.env file
2023-04-07 10:34:17 +02:00
braginini
4f8a156cb2 Load wintun lib 2023-04-06 18:54:06 +02:00
braginini
ff0b395fc5 Trying to fix wintun.dll 2023-04-06 18:45:44 +02:00
Zoltan Papp
0aad9169e9 Fix nil pointer exception (#790)
Nil pointer exception fix. The error handling was in wrong order.
2023-04-06 18:15:55 +02:00
braginini
237bfde1f2 Try loading wintun.dll 2023-04-06 16:54:57 +02:00
braginini
bfff0c36aa Fix windows tun device 2023-04-06 11:34:19 +02:00
braginini
41458a09e9 Use Bind implementation on Windows 2023-04-06 11:01:51 +02:00
Zoltan Papp
abd8287da8 Separate discover methods (#787) 2023-04-06 10:39:23 +02:00
Maycon Santos
1057cd211d Add scope and id token environment variables (#785) 2023-04-05 21:57:47 +02:00
Zoltan Papp
d3c49c71f2 Fix config reference in stdnet_android 2023-04-05 21:23:13 +02:00
braginini
49030ab71e Merge branch 'main' into feature/bind 2023-04-05 18:38:41 +02:00
braginini
7548780f8f Fix docker test workflow 2023-04-05 18:07:38 +02:00
braginini
277b65b833 Fix Windows compilation errros 2023-04-05 18:00:44 +02:00
braginini
071ad2b993 Fix Codacy and lint issues 2023-04-05 17:57:18 +02:00
Maycon Santos
32b345991a Support remote scope and use id token configuration (#784)
Some IDP requires different scope requests and
issue access tokens for different purposes

This change allow for remote configurable scopes
and the use of ID token
2023-04-05 17:46:34 +02:00
braginini
0e8a552334 Fix Codacy 2023-04-05 16:54:50 +02:00
braginini
005c4dd44a Make usage of WireGuard kernel configurable from env 2023-04-05 16:30:52 +02:00
Maycon Santos
e903522f8c Configurable port defaults from setup.env (#783)
Allow configuring management and signal ports from setup.env

Allow configuring Coturn range from setup.env
2023-04-05 15:22:06 +02:00
Maycon Santos
ea88ec6d27 Roolback configurable port defaults from setup.env 2023-04-05 11:42:14 +02:00
Maycon Santos
2be1a82f4a Configurable port defaults from setup.env
Allow configuring management and signal ports from setup.env

Allow configuring Coturn range from setup.env
2023-04-05 11:39:22 +02:00
braginini
367eff493a Ignore loopback interface 2023-04-04 17:49:28 +02:00
braginini
73a5bc33b3 Use bind no proxy when possible 2023-04-04 17:40:25 +02:00
braginini
87cbff1e7a Glue up all together 2023-04-04 16:50:35 +02:00
Maycon Santos
fe1ea4a2d0 Check multiple audience values (#781)
Some IDP use different audience for different clients. 
This update checks HTTP and Device authorization flow audience values.



---------

Co-authored-by: Givi Khojanashvili <gigovich@gmail.com>
2023-04-04 16:40:56 +02:00
Maycon Santos
f14f34cf2b Add token source and device flow audience variables (#780)
Supporting new dashboard option to configure a source token.

Adding configuration support for setting 
a different audience for device authorization flow.

fix custom id claim variable
2023-04-04 15:56:02 +02:00
braginini
24cc5c4ef2 Rename WireGuard proxy 2023-04-04 14:52:57 +02:00
Bethuel
109481e26d Use first available package manager (#782) 2023-04-04 14:26:17 +02:00
Bethuel
18098e7a7d Add single line installer (#775)
detect OS package manager
If a supported package manager is not available,
use binary installation

Check if desktop environment is available
Skip installing the UI client if SKIP_UI_APP is set to true

added tests for Ubuntu and macOS tests
2023-04-04 00:35:54 +02:00
Ruakij
5993982cca Add disable letsencrypt (#747)
Add NETBIRD_DISABLE_LETSENCRYPT support to explicit disable let's encrypt

Organize the setup.env.example variables into sections

Add traefik example
2023-04-04 00:21:40 +02:00
braginini
a42f7d2c3b Fix iface test 2023-04-03 18:34:28 +02:00
braginini
e2a3fc7558 Fix engine tests 2023-04-03 18:27:33 +02:00
braginini
5d191a8b9d Use custom stdnet 2023-04-03 17:37:05 +02:00
Zoltan Papp
86f9051a30 Fix/connection listener (#777)
Fix add/remove connection listener

In case we call the RemoveConnListener from Java then
we lose the reference from the original instance
2023-04-03 16:59:13 +02:00
Pascal Fischer
489892553a use UTC everywhere in server 2023-04-03 15:09:35 +02:00
Pascal Fischer
b05e30ac5a do not use UTC for time to stay consistent 2023-04-03 12:44:55 +02:00
pascal-fischer
769388cd21 Merge pull request #776 from netbirdio/feature/activity_events_for_pat
feature/activity_events_for_pat
2023-04-03 12:27:51 +02:00
pascal-fischer
c54fb9643c Merge pull request #774 from netbirdio/feature/add_pat_middleware
Feature/add pat middleware
2023-04-03 12:09:11 +02:00
Givi Khojanashvili
5dc0ff42a5 Fix broken auto-generated Rego rule (#769)
Default Rego policy generated from the rules in some cases is broken.
This change fixes the Rego template for rules to generate policies.

Also, file store load constantly regenerates policy objects from rules.
It allows updating/fixing of the default Rego template during releases.
2023-04-01 12:02:08 +02:00
Pascal Fischer
45badd2c39 add event store to user tests 2023-04-01 11:11:30 +02:00
Pascal Fischer
d3de035961 error responses always lower case + duplicate error response fix 2023-04-01 11:04:21 +02:00
braginini
79a8109d5e Extend stdnet to have an interface filter 2023-03-31 18:25:23 +02:00
Pascal Fischer
b2da0ae70f add activity events on PAT creation and deletion 2023-03-31 17:41:22 +02:00
Pascal Fischer
931c20c8fe fix test name 2023-03-31 12:45:10 +02:00
Pascal Fischer
2eaf4aa8d7 add test for auth middleware 2023-03-31 12:44:22 +02:00
Pascal Fischer
110067c00f change order for access control checks and aquire account lock after global lock 2023-03-31 12:03:53 +02:00
Pascal Fischer
32c96c15b8 disable linter errors by comment 2023-03-31 10:30:05 +02:00
Pascal Fischer
ca1dc5ac88 disable access control for token endpoint 2023-03-30 19:03:44 +02:00
Pascal Fischer
ce775d59ae revert codacy 2023-03-30 18:59:35 +02:00
Pascal Fischer
f273fe9f51 revert codacy 2023-03-30 18:54:55 +02:00
Pascal Fischer
e08af7fcdf codacy 2023-03-30 17:46:21 +02:00
Pascal Fischer
454240ca05 comments for codacy 2023-03-30 17:32:44 +02:00
Pascal Fischer
1343a3f00e add test + codacy 2023-03-30 16:43:39 +02:00
Pascal Fischer
2a79995706 fix linter 2023-03-30 16:22:15 +02:00
Pascal Fischer
e869882da1 fix merge 2023-03-30 16:14:51 +02:00
Pascal Fischer
6c8bb60632 fix merge 2023-03-30 16:06:46 +02:00
Pascal Fischer
4d7029d80c Merge branch 'main' into feature/add_pat_middleware
# Conflicts:
#	management/server/grpcserver.go
#	management/server/http/middleware/jwt.go
2023-03-30 16:06:21 +02:00
pascal-fischer
909f305728 Merge pull request #766 from netbirdio/feature/add_rest_endpoints_for_pat
Feature/add rest endpoints for pat
2023-03-30 15:55:48 +02:00
Pascal Fischer
5e2f66d591 fix codacy 2023-03-30 15:23:24 +02:00
braginini
ea44c1b723 Distinguish UDPMux and UDPMuxUniversal 2023-03-30 15:13:14 +02:00
Pascal Fischer
a7519859bc fix test 2023-03-30 14:15:44 +02:00
Pascal Fischer
9b000b89d5 Merge branch 'feature/add_rest_endpoints_for_pat' into feature/add_pat_middleware
# Conflicts:
#	management/server/mock_server/account_mock.go
2023-03-30 14:02:58 +02:00
Pascal Fischer
5c1acdbf2f move validation into account manager + func for get requests 2023-03-30 13:58:44 +02:00
Pascal Fischer
db3a9f0aa2 refactor jwt token validation and add PAT to middleware auth 2023-03-30 10:54:09 +02:00
Pascal Fischer
ecc4f8a10d fix Pat handler test 2023-03-29 19:13:01 +02:00
Pascal Fischer
03abdfa112 return empty object on all handlers instead of empty string 2023-03-29 18:46:40 +02:00
Pascal Fischer
9746a7f61a remove debug logs 2023-03-29 18:27:01 +02:00
Pascal Fischer
4ec6d5d20b remove debug logs 2023-03-29 18:23:10 +02:00
Pascal Fischer
3bab745142 last_used can be nil 2023-03-29 17:46:09 +02:00
Pascal Fischer
0ca3d27a80 update account mock 2023-03-29 15:25:44 +02:00
Pascal Fischer
c5942e6b33 store hashed token base64 encoded 2023-03-29 15:21:53 +02:00
Pascal Fischer
726ffb5740 add comments for exported functions 2023-03-29 15:06:54 +02:00
braginini
430f92094e Upgrade WireGuard 2023-03-29 13:58:42 +02:00
Maycon Santos
dfb7960cd4 Fix pre-shared key query name for android configuration (#773) 2023-03-29 10:41:14 +02:00
Zoltan Papp
ab0cf1b8aa Fix slice bounds out of range in msg decryption (#768) 2023-03-29 10:40:31 +02:00
Zoltan Papp
8ebd6ce963 Add OnDisconnecting service callback (#767)
Add OnDisconnecting service callback for mobile
2023-03-29 10:39:54 +02:00
Pascal Fischer
42ba0765c8 fix linter 2023-03-28 14:54:06 +02:00
Pascal Fischer
514403db37 use object instead of plain token for create response + handler test 2023-03-28 14:47:15 +02:00
Zoltan Papp
488d338ce8 Refactor the authentication part of mobile exports (#759)
Refactor the auth code into async calls for mobile framework

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2023-03-28 09:57:23 +02:00
braginini
d6c2b46019 Make it work 2023-03-27 19:09:01 +02:00
Pascal Fischer
6a75ec4ab7 fix http error codes 2023-03-27 17:42:05 +02:00
Pascal Fischer
b66e984ddd set limits for expiration 2023-03-27 17:28:24 +02:00
Pascal Fischer
c65a934107 refactor to use name instead of description 2023-03-27 16:28:49 +02:00
braginini
2e7d199a6d Merge remote-tracking branch 'origin/main' into feature/bind 2023-03-27 16:01:13 +02:00
Zoltan Papp
55ebf93815 Fix nil pointer exception when create config (#765)
The config stored in a wrong variable when has been generated a
new config
2023-03-27 15:37:58 +02:00
Pascal Fischer
9e74f30d2f fix delete token parameter lookup 2023-03-27 15:19:19 +02:00
braginini
8ac7eaf833 Bind init 2023-03-27 14:57:47 +02:00
Zoltan Papp
71d24e59e6 Add fqdn and address for notification listener (#757)
Extend the status notification listeners with FQDN and address
changes. It is required for mobile services.
2023-03-24 18:51:35 +01:00
Zoltan Papp
992cfe64e1 Add ipv6 test for stdnet pkg (#761) 2023-03-24 10:46:40 +01:00
Zoltan Papp
d1703479ff Add custom ice stdnet implementation (#754)
On Android, because of the hard SELinux policies can not list the
interfaces of the ICE package. Without it can not generate a host type
candidate. In this pull request, the list of interfaces comes via the Java
interface.
2023-03-24 08:40:39 +01:00
Pascal Fischer
de8608f99f add rest endpoints and update openapi doc 2023-03-21 16:02:19 +01:00
129 changed files with 5161 additions and 1178 deletions

View File

@@ -6,6 +6,10 @@ on:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: macos-latest

View File

@@ -6,6 +6,10 @@ on:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
strategy:
@@ -66,13 +70,13 @@ jobs:
run: go mod tidy
- name: Generate Iface Test bin
run: go test -c -o iface-testing.bin ./iface/...
run: go test -c -o iface-testing.bin ./iface/
- name: Generate RouteManager Test bin
run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
- name: Generate Engine Test bin
run: go test -c -o engine-testing.bin ./client/internal/*.go
run: go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: go test -c -o peer-testing.bin ./client/internal/peer/...
@@ -89,4 +93,4 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -6,47 +6,45 @@ on:
- main
pull_request:
env:
downloadPath: '${{ github.workspace }}\temp'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
pre:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- run: bash -x wireguard_nt.sh
working-directory: client
- uses: actions/upload-artifact@v2
with:
name: syso
path: client/*.syso
retention-days: 1
test:
needs: pre
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v2
uses: actions/setup-go@v4
id: go
with:
go-version: 1.19.x
- uses: actions/cache@v2
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
path: |
%LocalAppData%\go-build
~\go\pkg\mod
~\AppData\Local\go-build
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- uses: actions/download-artifact@v2
with:
name: syso
path: iface\
- name: Decompressing wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- name: Test
run: go test -tags=load_wgnt_from_rsrc -timeout 5m -p 1 ./...
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
- run: choco install -y sysinternals
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -1,5 +1,8 @@
name: golangci-lint
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
golangci:
name: lint

View File

@@ -0,0 +1,60 @@
name: Test installation Darwin
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
env:
SKIP_UI_APP: true
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
install-all:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename brew package
if: ${{ matrix.check_bin_install }}
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
echo "Error: NetBird UI is not installed"
exit 1
fi

View File

@@ -0,0 +1,38 @@
name: Test installation Linux
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
install-cli-only:
runs-on: ubuntu-latest
strategy:
matrix:
check_bin_install: [true, false]
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Rename apt package
if: ${{ matrix.check_bin_install }}
run: |
sudo mv /usr/bin/apt /usr/bin/apt.bak
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
- name: Run install script
run: |
sh ./release_files/install.sh
- name: Run tests
run: |
if ! command -v netbird &> /dev/null; then
echo "Error: netbird is not installed"
exit 1
fi

View File

@@ -9,9 +9,13 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.5"
SIGN_PIPE_VER: "v0.0.6"
GORELEASER_VER: "v1.14.1"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
release:
runs-on: ubuntu-latest
@@ -21,10 +25,6 @@ jobs:
uses: actions/checkout@v2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Generate syso with DLL
run: bash -x wireguard_nt.sh
working-directory: client
-
name: Set up Go
uses: actions/setup-go@v2
@@ -59,6 +59,17 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc amd64
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2

View File

@@ -6,6 +6,10 @@ on:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-latest
@@ -59,6 +63,11 @@ jobs:
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
CI_NETBIRD_TOKEN_SOURCE: "idToken"
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
run: |
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
@@ -68,6 +77,12 @@ jobs:
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
grep UseIDToken management.json | grep false
- name: run docker compose up
working-directory: infrastructure_files

View File

@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
@@ -23,6 +24,11 @@ type TunAdapter interface {
iface.TunAdapter
}
// IFaceDiscover export internal IFaceDiscover for mobile
type IFaceDiscover interface {
stdnet.ExternalIFaceDiscover
}
func init() {
formatter.SetLogcatFormatter(log.StandardLogger())
}
@@ -31,6 +37,7 @@ func init() {
type Client struct {
cfgFile string
tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
@@ -38,7 +45,7 @@ type Client struct {
}
// NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client {
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client {
lvl, _ := log.ParseLevel("trace")
log.SetLevel(lvl)
@@ -46,6 +53,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client {
cfgFile: cfgFile,
deviceName: deviceName,
tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
}
@@ -70,14 +78,14 @@ func (c *Client) Run(urlOpener URLOpener) error {
c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg)
err = auth.Login(urlOpener)
err = auth.login(urlOpener)
if err != nil {
return err
}
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter)
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover)
}
// Stop the internal client and free the resources
@@ -110,12 +118,12 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos}
}
// AddConnectionListener add new network connection listener
func (c *Client) AddConnectionListener(listener ConnectionListener) {
c.recorder.AddConnectionListener(listener)
// SetConnectionListener set the network connection listener
func (c *Client) SetConnectionListener(listener ConnectionListener) {
c.recorder.SetConnectionListener(listener)
}
// RemoveConnectionListener remove connection listener
func (c *Client) RemoveConnectionListener(listener ConnectionListener) {
c.recorder.RemoveConnectionListener(listener)
func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}

View File

@@ -17,6 +17,18 @@ import (
"github.com/netbirdio/netbird/client/internal"
)
// SSOListener is async listener for mobile framework
type SSOListener interface {
OnSuccess(bool)
OnError(error)
}
// ErrListener is async listener for mobile framework
type ErrListener interface {
OnSuccess()
OnError(error)
}
// URLOpener it is a callback interface. The Open function will be triggered if
// the backend want to show an url for the user
type URLOpener interface {
@@ -59,7 +71,18 @@ func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
// is not supported and returns false without saving the configuration. For other errors return false.
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
go func() {
sso, err := a.saveConfigIfSSOSupported()
if err != nil {
listener.OnError(err)
} else {
listener.OnSuccess(sso)
}
}()
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
@@ -83,7 +106,18 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
}
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
go func() {
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
if err != nil {
resultListener.OnError(err)
} else {
resultListener.OnSuccess()
}
}()
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
@@ -103,7 +137,18 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
}
// Login try register the client on the server
func (a *Auth) Login(urlOpener URLOpener) error {
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
go func() {
err := a.login(urlOpener)
if err != nil {
resultListener.OnError(err)
} else {
resultListener.OnSuccess()
}
}()
}
func (a *Auth) login(urlOpener URLOpener) error {
var needsLogin bool
// check if we need to generate JWT token
@@ -121,7 +166,7 @@ func (a *Auth) Login(urlOpener URLOpener) error {
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.AccessToken
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(a.ctx, func() error {
@@ -154,12 +199,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo,
}
}
hostedClient := internal.NewHostedDeviceFlow(
providerConfig.ProviderConfig.Audience,
providerConfig.ProviderConfig.ClientID,
providerConfig.ProviderConfig.TokenEndpoint,
providerConfig.ProviderConfig.DeviceAuthEndpoint,
)
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil {

View File

@@ -135,7 +135,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.AccessToken
jwtToken = tokenInfo.GetTokenToUse()
}
err = WithBackOff(func() error {
@@ -172,12 +172,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
}
}
hostedClient := internal.NewHostedDeviceFlow(
providerConfig.ProviderConfig.Audience,
providerConfig.ProviderConfig.ClientID,
providerConfig.ProviderConfig.TokenEndpoint,
providerConfig.ProviderConfig.DeviceAuthEndpoint,
)
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
if err != nil {

View File

@@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@@ -193,6 +193,7 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe`
Sleep 3000
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
RmDir /r "$INSTDIR"
SetShellVarContext current

View File

@@ -27,7 +27,7 @@ const (
)
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-"}
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
// ConfigInput carries configuration changes to the client
type ConfigInput struct {

View File

@@ -13,6 +13,7 @@ import (
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
@@ -22,7 +23,7 @@ import (
)
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter) error {
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error {
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@@ -58,9 +59,6 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return err
}
statusRecorder.MarkManagementDisconnected()
statusRecorder.ClientStart()
defer statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -81,12 +79,12 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
defer func() {
err = mgmClient.Close()
@@ -110,7 +108,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(),
KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
@@ -146,7 +144,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter)
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -162,7 +160,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
statusRecorder.ClientStart()
<-engineCtx.Done()
statusRecorder.ClientTeardown()
backOff.Reset()
@@ -193,12 +194,13 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*EngineConfig, error) {
engineConf := &EngineConfig{
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,

View File

@@ -34,6 +34,10 @@ type ProviderConfig struct {
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
@@ -91,9 +95,16 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
// keep compatibility with older management versions
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
@@ -116,5 +127,8 @@ func isProviderConfigValid(config ProviderConfig) error {
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
}
return nil
}

View File

@@ -9,7 +9,10 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
)
@@ -199,7 +202,11 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil)
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}

View File

@@ -18,8 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
@@ -49,6 +49,8 @@ type EngineConfig struct {
// TunAdapter is option. It is necessary for mobile version.
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
// WgAddr is a Wireguard local address (Netbird Network IP)
WgAddr string
@@ -163,60 +165,77 @@ func (e *Engine) Stop() error {
return nil
}
// Start creates a new Wireguard tunnel interface and listens to events from Signal and Management services
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
wgIfaceName := e.config.WgIfaceName
wgIFaceName := e.config.WgIfaceName
wgAddr := e.config.WgAddr
myPrivateKey := e.config.WgPrivateKey
var err error
e.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter)
transportNet, err := e.newStdNet()
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIfaceName, err.Error())
log.Warnf("failed to create pion's stdnet: %s", err)
}
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter, transportNet)
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
return err
}
networkName := "udp"
if e.config.DisableIPv6Discovery {
networkName = "udp4"
}
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
e.close()
return err
}
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
e.close()
return err
}
e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: e.udpMuxConn})
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx})
err = e.wgInterface.Create()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error())
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
e.close()
return err
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error())
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
e.close()
return err
}
if e.wgInterface.IsUserspaceBind() {
iceBind := e.wgInterface.GetBind()
udpMux, err := iceBind.GetICEMux()
if err != nil {
e.close()
return err
}
e.udpMux = udpMux.UDPMuxDefault
e.udpMuxSrflx = udpMux
log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String())
} else {
networkName := "udp"
if e.config.DisableIPv6Discovery {
networkName = "udp4"
}
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
e.close()
return err
}
udpMuxParams := ice.UDPMuxParams{
UDPConn: e.udpMuxConn,
Net: transportNet,
}
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
if err != nil {
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error())
e.close()
return err
}
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
}
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
if e.dnsServer == nil {
@@ -485,7 +504,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
IP: e.config.WgAddr,
PubKey: e.config.WgPrivateKey.PublicKey().String(),
KernelInterface: iface.WireguardModuleIsLoaded(),
KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: conf.GetFqdn(),
})
@@ -789,14 +808,6 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
proxyConfig := proxy.Config{
RemoteKey: pubKey,
WgListenAddr: fmt.Sprintf("127.0.0.1:%d", e.config.WgPort),
WgInterface: e.wgInterface,
AllowedIps: allowedIPs,
PreSharedKey: e.config.PreSharedKey,
}
// randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{
@@ -808,12 +819,12 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
Timeout: timeout,
UDPMux: e.udpMux,
UDPMuxSrflx: e.udpMuxSrflx,
ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(),
AllowedIPs: allowedIPs,
}
peerConn, err := peer.NewConn(config, e.statusRecorder)
peerConn, err := peer.NewConn(config, e.wgInterface, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover)
if err != nil {
return nil, err
}
@@ -995,12 +1006,6 @@ func (e *Engine) close() {
}
}
if e.udpMuxSrflx != nil {
if err := e.udpMuxSrflx.Close(); err != nil {
log.Debugf("close server reflexive udp mux: %v", err)
}
}
if e.udpMuxConn != nil {
if err := e.udpMuxConn.Close(); err != nil {
log.Debugf("close udp mux connection: %v", err)

View File

@@ -0,0 +1,11 @@
//go:build !android
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.config.IFaceBlackList)
}

View File

@@ -0,0 +1,7 @@
package internal
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(e.config.IFaceDiscover, e.config.IFaceBlackList)
}

View File

@@ -3,6 +3,8 @@ package internal
import (
"context"
"fmt"
"github.com/netbirdio/netbird/iface/bind"
"github.com/pion/transport/v2/stdnet"
"net"
"net/netip"
"os"
@@ -207,11 +209,23 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
conn, err := net.ListenUDP("udp4", nil)
if err != nil {
t.Fatal(err)
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
type testCase struct {
name string
@@ -549,7 +563,11 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
@@ -714,7 +732,11 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{

View File

@@ -35,15 +35,6 @@ type DeviceAuthInfo struct {
Interval int `json:"interval"`
}
// TokenInfo holds information of issued access token
type TokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
// HostedGrantType grant type for device flow on Hosted
const (
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
@@ -52,16 +43,7 @@ const (
// Hosted client
type Hosted struct {
// Hosted API Audience for validation
Audience string
// Hosted Native application client id
ClientID string
// Hosted Native application request scope
Scope string
// TokenEndpoint to request access token
TokenEndpoint string
// DeviceAuthEndpoint to request device authorization code
DeviceAuthEndpoint string
providerConfig ProviderConfig
HTTPClient HTTPClient
}
@@ -70,7 +52,7 @@ type Hosted struct {
type RequestDeviceCodePayload struct {
Audience string `json:"audience"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
Scope string `json:"scope"`
}
// TokenRequestPayload used for requesting the auth0 token
@@ -93,8 +75,26 @@ type Claims struct {
Audience interface{} `json:"aud"`
}
// TokenInfo holds information of issued access token
type TokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
UseIDToken bool `json:"-"`
}
// GetTokenToUse returns either the access or id token based on UseIDToken field
func (t TokenInfo) GetTokenToUse() string {
if t.UseIDToken {
return t.IDToken
}
return t.AccessToken
}
// NewHostedDeviceFlow returns an Hosted OAuth client
func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted {
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -104,27 +104,23 @@ func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string,
}
return &Hosted{
Audience: audience,
ClientID: clientID,
Scope: "openid",
TokenEndpoint: tokenEndpoint,
HTTPClient: httpClient,
DeviceAuthEndpoint: deviceAuthEndpoint,
providerConfig: config,
HTTPClient: httpClient,
}
}
// GetClientID returns the provider client id
func (h *Hosted) GetClientID(ctx context.Context) string {
return h.ClientID
return h.providerConfig.ClientID
}
// RequestDeviceCode requests a device code login flow information from Hosted
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
form := url.Values{}
form.Add("client_id", h.ClientID)
form.Add("audience", h.Audience)
form.Add("scope", h.Scope)
req, err := http.NewRequest("POST", h.DeviceAuthEndpoint,
form.Add("client_id", h.providerConfig.ClientID)
form.Add("audience", h.providerConfig.Audience)
form.Add("scope", h.providerConfig.Scope)
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
strings.NewReader(form.Encode()))
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
@@ -157,10 +153,10 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", h.ClientID)
form.Add("client_id", h.providerConfig.ClientID)
form.Add("grant_type", HostedGrantType)
form.Add("device_code", info.DeviceCode)
req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode()))
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
}
@@ -225,18 +221,20 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
}
err = isValidAccessToken(tokenResponse.AccessToken, h.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
UseIDToken: h.providerConfig.UseIDToken,
}
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
return tokenInfo, err
}
}

View File

@@ -3,14 +3,15 @@ package internal
import (
"context"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
)
type mockHTTPClient struct {
@@ -113,12 +114,15 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
}
hosted := Hosted{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
HTTPClient: &httpClient,
providerConfig: ProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
authInfo, err := hosted.RequestDeviceCode(context.TODO())
@@ -275,12 +279,15 @@ func TestHosted_WaitToken(t *testing.T) {
}
hosted := Hosted{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
HTTPClient: &httpClient,
}
providerConfig: ProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient}
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()

View File

@@ -9,15 +9,14 @@ import (
"time"
"github.com/pion/ice/v2"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/version"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/version"
)
// ConnConfig is a peer Connection configuration
@@ -38,14 +37,14 @@ type ConnConfig struct {
Timeout time.Duration
ProxyConfig proxy.Config
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
LocalWgPort int
NATExternalIPs []string
AllowedIPs string
}
// OfferAnswer represents a session establishment offer or answer
@@ -93,6 +92,10 @@ type Conn struct {
proxy proxy.Proxy
remoteModeCh chan ModeMessage
meta meta
wgIface *iface.WGIface
adapter iface.TunAdapter
iFaceDiscover stdnet.ExternalIFaceDiscover
}
// meta holds meta information about a connection
@@ -118,7 +121,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) {
func NewConn(config ConnConfig, wgIface *iface.WGIface, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
return &Conn{
config: config,
mu: sync.Mutex{},
@@ -128,41 +131,20 @@ func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) {
remoteAnswerCh: make(chan OfferAnswer),
statusRecorder: statusRecorder,
remoteModeCh: make(chan ModeMessage, 1),
adapter: adapter,
iFaceDiscover: iFaceDiscover,
wgIface: wgIface,
}, nil
}
// interfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them
func interfaceFilter(blackList []string) func(string) bool {
return func(iFace string) bool {
for _, s := range blackList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
return false
}
}
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
return true
}
defer func() {
_ = wg.Close()
}()
_, err = wg.Device(iFace)
return err != nil
}
}
func (conn *Conn) reCreateAgent() error {
conn.mu.Lock()
defer conn.mu.Unlock()
failedTimeout := 6 * time.Second
transportNet, err := stdnet.NewNet()
var err error
transportNet, err := conn.newStdNet()
if err != nil {
log.Warnf("failed to create pion's stdnet: %s", err)
}
@@ -172,7 +154,7 @@ func (conn *Conn) reCreateAgent() error {
Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay},
FailedTimeout: &failedTimeout,
InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList),
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs,
@@ -215,7 +197,7 @@ func (conn *Conn) Open() error {
peerState := State{PubKey: conn.config.Key}
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
peerState.IP = strings.Split(conn.config.AllowedIPs, "/")[0]
peerState.ConnStatusUpdate = time.Now()
peerState.ConnStatus = conn.status
@@ -312,7 +294,7 @@ func (conn *Conn) Open() error {
return err
}
if conn.proxy.Type() == proxy.TypeNoProxy {
if conn.proxy.Type() == proxy.TypeDirectNoProxy {
host, _, _ := net.SplitHostPort(remoteConn.LocalAddr().String())
rhost, _, _ := net.SplitHostPort(remoteConn.RemoteAddr().String())
// direct Wireguard connection
@@ -334,14 +316,21 @@ func (conn *Conn) Open() error {
// useProxy determines whether a direct connection (without a go proxy) is possible
//
// There are 2 cases:
// There are 3 cases:
//
// * When neither candidate is from hard nat and one of the peers has a public IP
//
// * both peers are in the same private network
//
// * Local peer uses userspace interface with bind.ICEBind and is not relayed
//
// Please note, that this check happens when peers were already able to ping each other using ICE layer.
func shouldUseProxy(pair *ice.CandidatePair) bool {
func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool {
if !isRelayCandidate(pair.Local) && userspaceBind {
return false
}
if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) {
return false
}
@@ -350,13 +339,37 @@ func shouldUseProxy(pair *ice.CandidatePair) bool {
return false
}
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) {
if isHostCandidateWithPrivateIP(pair.Local) && isHostCandidateWithPrivateIP(pair.Remote) && isSameNetworkPrefix(pair) {
return false
}
return true
}
func isSameNetworkPrefix(pair *ice.CandidatePair) bool {
localIPStr, _, err := net.SplitHostPort(pair.Local.Address())
if err != nil {
return false
}
remoteIPStr, _, err := net.SplitHostPort(pair.Remote.Address())
if err != nil {
return false
}
localIP := net.ParseIP(localIPStr)
remoteIP := net.ParseIP(remoteIPStr)
if localIP == nil || remoteIP == nil {
return false
}
// only consider /16 networks
mask := net.IPMask{255, 255, 0, 0}
return localIP.Mask(mask).Equal(remoteIP.Mask(mask))
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func isHardNATCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay || candidate.Type() == ice.CandidateTypePeerReflexive
}
@@ -405,7 +418,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
peerState.Direct = p.Type() == proxy.TypeNoProxy
peerState.Direct = p.Type() == proxy.TypeDirectNoProxy
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -416,8 +429,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
}
func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy {
useProxy := shouldUseProxy(pair)
useProxy := shouldUseProxy(pair, conn.wgIface.IsUserspaceBind())
localDirectMode := !useProxy
remoteDirectMode := localDirectMode
@@ -427,13 +439,17 @@ func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgP
remoteDirectMode = conn.receiveRemoteDirectMode()
}
if conn.wgIface.IsUserspaceBind() && localDirectMode {
return proxy.NewNoProxy(conn.config.ProxyConfig)
}
if localDirectMode && remoteDirectMode {
log.Debugf("using WireGuard direct mode with peer %s", conn.config.Key)
return proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
//wgInterface *iface.WGIface, remoteKey string, allowedIps string, preSharedKey *wgtypes.Key, remoteWgPort int)
return proxy.NewDirectNoProxy(conn.wgIface, conn.config.Key, conn.config.AllowedIPs, remoteWgPort)
}
log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key)
return proxy.NewWireguardProxy(conn.config.ProxyConfig)
return proxy.NewWireGuardProxy(conn.config.ProxyConfig)
}
func (conn *Conn) sendLocalDirectMode(localMode bool) {

View File

@@ -1,6 +1,7 @@
package peer
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"sync"
"testing"
"time"
@@ -28,7 +29,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
ignore := []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"}
filter := interfaceFilter(ignore)
filter := stdnet.InterfaceFilter(ignore)
for _, s := range ignore {
assert.Equal(t, filter(s), false)
@@ -37,7 +38,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
conn, err := NewConn(connConf, nil)
conn, err := NewConn(connConf, nil, nil, nil)
if err != nil {
return
}
@@ -49,7 +50,7 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
if err != nil {
return
}
@@ -83,7 +84,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
if err != nil {
return
}
@@ -116,7 +117,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
}
func TestConn_Status(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
if err != nil {
return
}
@@ -143,7 +144,7 @@ func TestConn_Status(t *testing.T) {
func TestConn_Close(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
if err != nil {
return
}
@@ -202,12 +203,13 @@ func TestConn_ShouldUseProxy(t *testing.T) {
}
privateHostCandidate := &mockICECandidate{
AddressFunc: func() string {
return "10.0.0.1"
return "10.0.0.1:44576"
},
TypeFunc: func() ice.CandidateType {
return ice.CandidateTypeHost
},
}
srflxCandidate := &mockICECandidate{
AddressFunc: func() string {
return "1.1.1.1"
@@ -324,7 +326,7 @@ func TestConn_ShouldUseProxy(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := shouldUseProxy(testCase.candatePair)
result := shouldUseProxy(testCase.candatePair, false)
if result != testCase.expected {
t.Errorf("got a different result. Expected %t Got %t", testCase.expected, result)
}
@@ -365,7 +367,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: true,
inputRemoteModeMessage: true,
expected: proxy.TypeWireguard,
expected: proxy.TypeWireGuard,
},
{
name: "Should Result In Using Wireguard Proxy When Remote Eval Is Use Proxy",
@@ -375,7 +377,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: true,
inputRemoteModeMessage: false,
expected: proxy.TypeWireguard,
expected: proxy.TypeWireGuard,
},
{
name: "Should Result In Using Wireguard Proxy When Remote Direct Mode Support Is False And Local Eval Is Use Proxy",
@@ -385,7 +387,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: false,
inputRemoteModeMessage: false,
expected: proxy.TypeWireguard,
expected: proxy.TypeWireGuard,
},
{
name: "Should Result In Using Direct When Remote Direct Mode Support Is False And Local Eval Is No Use Proxy",
@@ -395,7 +397,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: false,
inputRemoteModeMessage: false,
expected: proxy.TypeNoProxy,
expected: proxy.TypeDirectNoProxy,
},
{
name: "Should Result In Using Direct When Local And Remote Eval Is No Proxy",
@@ -405,13 +407,13 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
},
inputDirectModeSupport: true,
inputRemoteModeMessage: true,
expected: proxy.TypeNoProxy,
expected: proxy.TypeDirectNoProxy,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
g := errgroup.Group{}
conn, err := NewConn(connConf, nil)
conn, err := NewConn(connConf, nil, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -5,5 +5,7 @@ type Listener interface {
OnConnected()
OnDisconnected()
OnConnecting()
OnDisconnecting()
OnAddressChanged(string, string)
OnPeersListChanged(int)
}

View File

@@ -8,117 +8,135 @@ const (
stateDisconnected = iota
stateConnected
stateConnecting
stateDisconnecting
)
type notifier struct {
serverStateLock sync.Mutex
listenersLock sync.Mutex
listeners map[Listener]struct{}
currentServerState bool
listener Listener
currentClientState bool
lastNotification int
}
func newNotifier() *notifier {
return &notifier{
listeners: make(map[Listener]struct{}),
}
return &notifier{}
}
func (n *notifier) addListener(listener Listener) {
func (n *notifier) setListener(listener Listener) {
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
n.serverStateLock.Lock()
go n.notifyListener(listener, n.lastNotification)
n.notifyListener(listener, n.lastNotification)
n.serverStateLock.Unlock()
n.listeners[listener] = struct{}{}
n.listener = listener
}
func (n *notifier) removeListener(listener Listener) {
func (n *notifier) removeListener() {
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
delete(n.listeners, listener)
n.listener = nil
}
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
var newState bool
if mgmState && signalState {
newState = true
} else {
newState = false
}
calculatedState := n.calculateState(mgmState, signalState)
if !n.isServerStateChanged(newState) {
if !n.isServerStateChanged(calculatedState) {
return
}
n.currentServerState = newState
n.lastNotification = n.calculateState(newState, n.currentClientState)
n.lastNotification = calculatedState
go n.notifyAll(n.lastNotification)
n.notify(n.lastNotification)
}
func (n *notifier) clientStart() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = true
n.lastNotification = n.calculateState(n.currentServerState, true)
go n.notifyAll(n.lastNotification)
n.lastNotification = stateConnected
n.notify(n.lastNotification)
}
func (n *notifier) clientStop() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = false
n.lastNotification = n.calculateState(n.currentServerState, false)
go n.notifyAll(n.lastNotification)
n.lastNotification = stateDisconnected
n.notify(n.lastNotification)
}
func (n *notifier) isServerStateChanged(newState bool) bool {
return n.currentServerState != newState
func (n *notifier) clientTearDown() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = false
n.lastNotification = stateDisconnecting
n.notify(n.lastNotification)
}
func (n *notifier) notifyAll(state int) {
func (n *notifier) isServerStateChanged(newState int) bool {
return n.lastNotification != newState
}
func (n *notifier) notify(state int) {
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
for l := range n.listeners {
n.notifyListener(l, state)
if n.listener == nil {
return
}
n.notifyListener(n.listener, state)
}
func (n *notifier) notifyListener(l Listener, state int) {
switch state {
case stateDisconnected:
l.OnDisconnected()
case stateConnected:
l.OnConnected()
case stateConnecting:
l.OnConnecting()
}
go func() {
switch state {
case stateDisconnected:
l.OnDisconnected()
case stateConnected:
l.OnConnected()
case stateConnecting:
l.OnConnecting()
case stateDisconnecting:
l.OnDisconnecting()
}
}()
}
func (n *notifier) calculateState(serverState bool, clientState bool) int {
if serverState && clientState {
func (n *notifier) calculateState(managementConn, signalConn bool) int {
if managementConn && signalConn {
return stateConnected
}
if !clientState {
if !managementConn && !signalConn {
return stateDisconnected
}
if n.lastNotification == stateDisconnecting {
return stateDisconnecting
}
return stateConnecting
}
func (n *notifier) peerListChanged(numOfPeers int) {
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
for l := range n.listeners {
l.OnPeersListChanged(numOfPeers)
if n.listener == nil {
return
}
n.listener.OnPeersListChanged(numOfPeers)
}
func (n *notifier) localAddressChanged(fqdn, address string) {
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
if n.listener == nil {
return
}
n.listener.OnAddressChanged(fqdn, address)
}

View File

@@ -1,32 +1,97 @@
package peer
import (
"sync"
"testing"
)
type mocListener struct {
lastState int
wg sync.WaitGroup
peers int
}
func (l *mocListener) OnConnected() {
l.lastState = stateConnected
l.wg.Done()
}
func (l *mocListener) OnDisconnected() {
l.lastState = stateDisconnected
l.wg.Done()
}
func (l *mocListener) OnConnecting() {
l.lastState = stateConnecting
l.wg.Done()
}
func (l *mocListener) OnDisconnecting() {
l.lastState = stateDisconnecting
l.wg.Done()
}
func (l *mocListener) OnAddressChanged(host, addr string) {
}
func (l *mocListener) OnPeersListChanged(size int) {
l.peers = size
}
func (l *mocListener) setWaiter() {
l.wg.Add(1)
}
func (l *mocListener) wait() {
l.wg.Wait()
}
func Test_notifier_serverState(t *testing.T) {
type scenario struct {
name string
expected bool
expected int
mgmState bool
signalState bool
}
scenarios := []scenario{
{"connected", true, true, true},
{"mgm down", false, false, true},
{"signal down", false, true, false},
{"disconnected", false, false, false},
{"connected", stateConnected, true, true},
{"mgm down", stateConnecting, false, true},
{"signal down", stateConnecting, true, false},
{"disconnected", stateDisconnected, false, false},
}
for _, tt := range scenarios {
t.Run(tt.name, func(t *testing.T) {
n := newNotifier()
n.updateServerStates(tt.mgmState, tt.signalState)
if n.currentServerState != tt.expected {
t.Errorf("invalid serverstate: %t, expected: %t", n.currentServerState, tt.expected)
if n.lastNotification != tt.expected {
t.Errorf("invalid serverstate: %d, expected: %d", n.lastNotification, tt.expected)
}
})
}
}
func Test_notifier_SetListener(t *testing.T) {
listener := &mocListener{}
listener.setWaiter()
n := newNotifier()
n.lastNotification = stateConnecting
n.setListener(listener)
listener.wait()
if listener.lastState != n.lastNotification {
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
}
}
func Test_notifier_RemoveListener(t *testing.T) {
listener := &mocListener{}
listener.setWaiter()
n := newNotifier()
n.lastNotification = stateConnecting
n.setListener(listener)
n.removeListener()
n.peerListChanged(1)
if listener.peers != 0 {
t.Errorf("invalid state: %d", listener.peers)
}
}

View File

@@ -190,6 +190,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
defer d.mux.Unlock()
d.localPeer = localPeerState
d.notifyAddressChanged()
}
// CleanLocalPeerState cleans local peer status
@@ -198,6 +199,7 @@ func (d *Status) CleanLocalPeerState() {
defer d.mux.Unlock()
d.localPeer = LocalPeerState{}
d.notifyAddressChanged()
}
// MarkManagementDisconnected sets ManagementState to disconnected
@@ -215,7 +217,7 @@ func (d *Status) MarkManagementConnected() {
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = true
d.managementState = true
}
// UpdateSignalAddress update the address of the signal server
@@ -238,7 +240,7 @@ func (d *Status) MarkSignalDisconnected() {
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = false
d.signalState = false
}
// MarkSignalConnected sets SignalState to connected
@@ -286,14 +288,19 @@ func (d *Status) ClientStop() {
d.notifier.clientStop()
}
// AddConnectionListener add a listener to the notifier
func (d *Status) AddConnectionListener(listener Listener) {
d.notifier.addListener(listener)
// ClientTeardown will notify all listeners about the service is under teardown
func (d *Status) ClientTeardown() {
d.notifier.clientTearDown()
}
// RemoveConnectionListener remove a listener from the notifier
func (d *Status) RemoveConnectionListener(listener Listener) {
d.notifier.removeListener(listener)
// SetConnectionListener set a listener to the notifier
func (d *Status) SetConnectionListener(listener Listener) {
d.notifier.setListener(listener)
}
// RemoveConnectionListener remove the listener from the notifier
func (d *Status) RemoveConnectionListener() {
d.notifier.removeListener()
}
func (d *Status) onConnectionChanged() {
@@ -303,3 +310,7 @@ func (d *Status) onConnectionChanged() {
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(len(d.peers))
}
func (d *Status) notifyAddressChanged() {
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
}

View File

@@ -0,0 +1,11 @@
//go:build !android
package peer
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(conn.config.InterfaceBlackList)
}

View File

@@ -0,0 +1,7 @@
package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
}

View File

@@ -0,0 +1,67 @@
package proxy
import (
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
)
// DirectNoProxy is used when there is no need for a proxy between ICE and WireGuard.
// This is possible in either of these cases:
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// DirectNoProxy will just update remote peer with a remote host and fixed WireGuard port (r.g. 51820).
// In order DirectNoProxy to work, WireGuard port has to be fixed for the time being.
type DirectNoProxy struct {
wgInterface *iface.WGIface
remoteKey string
allowedIps string
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
remoteWgListenPort int
}
// NewDirectNoProxy creates a new DirectNoProxy with a provided config and remote peer's WireGuard listen port
func NewDirectNoProxy(wgInterface *iface.WGIface, remoteKey string, allowedIps string, remoteWgPort int) *DirectNoProxy {
return &DirectNoProxy{
wgInterface: wgInterface,
remoteKey: remoteKey,
allowedIps: allowedIps,
remoteWgListenPort: remoteWgPort}
}
// Close removes peer from the WireGuard interface
func (p *DirectNoProxy) Close() error {
err := p.wgInterface.RemovePeer(p.remoteKey)
if err != nil {
return err
}
return nil
}
// Start just updates WireGuard peer with the remote IP and default WireGuard port
func (p *DirectNoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using DirectNoProxy while connecting to peer %s", p.remoteKey)
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
addr.Port = p.remoteWgListenPort
err = p.wgInterface.UpdatePeer(p.remoteKey, p.allowedIps, addr)
if err != nil {
return err
}
return nil
}
// Type returns the type of this proxy
func (p *DirectNoProxy) Type() Type {
return TypeDirectNoProxy
}

View File

@@ -5,24 +5,18 @@ import (
"net"
)
// NoProxy is used when there is no need for a proxy between ICE and Wireguard.
// This is possible in either of these cases:
// - peers are in the same local network
// - one of the peers has a public static IP (host)
// NoProxy will just update remote peer with a remote host and fixed Wireguard port (r.g. 51820).
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
// NoProxy is used just to configure WireGuard without any local proxy in between.
// Used when the WireGuard interface is userspace and uses bind.ICEBind
type NoProxy struct {
config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
}
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port
func NewNoProxy(config Config, remoteWgPort int) *NoProxy {
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort}
// NewNoProxy creates a new NoProxy with a provided config
func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config}
}
// Close removes peer from the WireGuard interface
func (p *NoProxy) Close() error {
err := p.config.WgInterface.RemovePeer(p.config.RemoteKey)
if err != nil {
@@ -31,23 +25,16 @@ func (p *NoProxy) Close() error {
return nil
}
// Start just updates Wireguard peer with the remote IP and default Wireguard port
// Start just updates WireGuard peer with the remote address
func (p *NoProxy) Start(remoteConn net.Conn) error {
log.Debugf("using NoProxy while connecting to peer %s", p.config.RemoteKey)
log.Debugf("using NoProxy to connect to peer %s at %s", p.config.RemoteKey, remoteConn.RemoteAddr().String())
addr, err := net.ResolveUDPAddr("udp", remoteConn.RemoteAddr().String())
if err != nil {
return err
}
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
return p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)
if err != nil {
return err
}
return nil
}
func (p *NoProxy) Type() Type {

View File

@@ -1,31 +1,19 @@
package proxy
import (
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"io"
"net"
"time"
)
const DefaultWgKeepAlive = 25 * time.Second
type Type string
const (
TypeNoProxy Type = "NoProxy"
TypeWireguard Type = "Wireguard"
TypeDummy Type = "Dummy"
TypeDirectNoProxy Type = "DirectNoProxy"
TypeWireGuard Type = "WireGuard"
TypeDummy Type = "Dummy"
TypeNoProxy Type = "NoProxy"
)
type Config struct {
WgListenAddr string
RemoteKey string
WgInterface *iface.WGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
type Proxy interface {
io.Closer
// Start creates a local remoteConn and starts proxying data from/to remoteConn

View File

@@ -6,8 +6,8 @@ import (
"net"
)
// WireguardProxy proxies
type WireguardProxy struct {
// WireGuardProxy proxies
type WireGuardProxy struct {
ctx context.Context
cancel context.CancelFunc
@@ -17,13 +17,13 @@ type WireguardProxy struct {
localConn net.Conn
}
func NewWireguardProxy(config Config) *WireguardProxy {
p := &WireguardProxy{config: config}
func NewWireGuardProxy(config Config) *WireGuardProxy {
p := &WireGuardProxy{config: config}
p.ctx, p.cancel = context.WithCancel(context.Background())
return p
}
func (p *WireguardProxy) updateEndpoint() error {
func (p *WireGuardProxy) updateEndpoint() error {
udpAddr, err := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
if err != nil {
return err
@@ -38,7 +38,7 @@ func (p *WireguardProxy) updateEndpoint() error {
return nil
}
func (p *WireguardProxy) Start(remoteConn net.Conn) error {
func (p *WireGuardProxy) Start(remoteConn net.Conn) error {
p.remoteConn = remoteConn
var err error
@@ -60,7 +60,7 @@ func (p *WireguardProxy) Start(remoteConn net.Conn) error {
return nil
}
func (p *WireguardProxy) Close() error {
func (p *WireGuardProxy) Close() error {
p.cancel()
if c := p.localConn; c != nil {
err := p.localConn.Close()
@@ -77,7 +77,7 @@ func (p *WireguardProxy) Close() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WireguardProxy) proxyToRemote() {
func (p *WireGuardProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
@@ -101,7 +101,7 @@ func (p *WireguardProxy) proxyToRemote() {
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WireguardProxy) proxyToLocal() {
func (p *WireGuardProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
@@ -123,6 +123,6 @@ func (p *WireguardProxy) proxyToLocal() {
}
}
func (p *WireguardProxy) Type() Type {
return TypeWireguard
func (p *WireGuardProxy) Type() Type {
return TypeWireGuard
}

View File

@@ -3,6 +3,7 @@ package routemanager
import (
"context"
"fmt"
"github.com/pion/transport/v2/stdnet"
"net/netip"
"runtime"
"testing"
@@ -391,7 +392,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

View File

@@ -3,6 +3,7 @@ package routemanager
import (
"fmt"
"github.com/netbirdio/netbird/iface"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require"
"net"
"net/netip"
@@ -32,7 +33,11 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()

View File

@@ -0,0 +1,14 @@
package stdnet
import "github.com/pion/transport/v2"
// ExternalIFaceDiscover provide an option for external services (mobile)
// to collect network interface information
type ExternalIFaceDiscover interface {
// IFaces return with the description of the interfaces
IFaces() (string, error)
}
type iFaceDiscover interface {
iFaces() ([]*transport.Interface, error)
}

View File

@@ -0,0 +1,95 @@
package stdnet
import (
"fmt"
"net"
"strings"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
)
type mobileIFaceDiscover struct {
externalDiscover ExternalIFaceDiscover
}
func newMobileIFaceDiscover(externalDiscover ExternalIFaceDiscover) *mobileIFaceDiscover {
return &mobileIFaceDiscover{
externalDiscover: externalDiscover,
}
}
func (m *mobileIFaceDiscover) iFaces() ([]*transport.Interface, error) {
ifacesString, err := m.externalDiscover.IFaces()
if err != nil {
return nil, err
}
interfaces := m.parseInterfacesString(ifacesString)
return interfaces, nil
}
func (m *mobileIFaceDiscover) parseInterfacesString(interfaces string) []*transport.Interface {
ifs := []*transport.Interface{}
for _, iface := range strings.Split(interfaces, "\n") {
if strings.TrimSpace(iface) == "" {
continue
}
fields := strings.Split(iface, "|")
if len(fields) != 2 {
log.Warnf("parseInterfacesString: unable to split %q", iface)
continue
}
var name string
var index, mtu int
var up, broadcast, loopback, pointToPoint, multicast bool
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
if err != nil {
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
continue
}
newIf := net.Interface{
Name: name,
Index: index,
MTU: mtu,
}
if up {
newIf.Flags |= net.FlagUp
}
if broadcast {
newIf.Flags |= net.FlagBroadcast
}
if loopback {
newIf.Flags |= net.FlagLoopback
}
if pointToPoint {
newIf.Flags |= net.FlagPointToPoint
}
if multicast {
newIf.Flags |= net.FlagMulticast
}
ifc := transport.NewInterface(newIf)
addrs := strings.Trim(fields[1], " \n")
foundAddress := false
for _, addr := range strings.Split(addrs, " ") {
ip, ipNet, err := net.ParseCIDR(addr)
if err != nil {
log.Warnf("%s", err)
continue
}
ipNet.IP = ip
ifc.AddAddress(ipNet)
foundAddress = true
}
if foundAddress {
ifs = append(ifs, ifc)
}
}
return ifs
}

View File

@@ -0,0 +1,68 @@
package stdnet
import (
"fmt"
"testing"
)
func Test_parseInterfacesString(t *testing.T) {
testData := []struct {
name string
index int
mtu int
up bool
broadcast bool
loopBack bool
pointToPoint bool
multicast bool
addr string
}{
{"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"},
{"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"},
{"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"},
}
var exampleString string
for _, d := range testData {
exampleString = fmt.Sprintf("%s\n%s %d %d %t %t %t %t %t | %s", exampleString,
d.name,
d.index,
d.mtu,
d.up,
d.broadcast,
d.loopBack,
d.pointToPoint,
d.multicast,
d.addr)
}
d := mobileIFaceDiscover{}
nets := d.parseInterfacesString(exampleString)
if len(nets) == 0 {
t.Fatalf("failed to parse interfaces")
}
for i, net := range nets {
if net.MTU != testData[i].mtu {
t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu)
}
if net.Interface.Name != testData[i].name {
t.Errorf("invalid interface name: %s, expected: %s", net.Interface.Name, testData[i].name)
}
addr, err := net.Addrs()
if err != nil {
t.Fatal(err)
}
if len(addr) == 0 {
t.Errorf("invalid address parsing")
}
if addr[0].String() != testData[i].addr {
t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr)
}
}
}

View File

@@ -0,0 +1,36 @@
package stdnet
import (
"net"
"github.com/pion/transport/v2"
)
type pionDiscover struct {
}
func (d pionDiscover) iFaces() ([]*transport.Interface, error) {
ifs := []*transport.Interface{}
oifs, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, oif := range oifs {
ifc := transport.NewInterface(oif)
addrs, err := oif.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
ifc.AddAddress(addr)
}
ifs = append(ifs, ifc)
}
return ifs, nil
}

View File

@@ -0,0 +1,40 @@
package stdnet
import (
"strings"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
)
// InterfaceFilter is a function passed to ICE Agent to filter out not allowed interfaces
// to avoid building tunnel over them.
func InterfaceFilter(disallowList []string) func(string) bool {
return func(iFace string) bool {
if strings.HasPrefix(iFace, "lo") {
// hardcoded loopback check to support already installed agents
return false
}
for _, s := range disallowList {
if strings.HasPrefix(iFace, s) {
log.Debugf("ignoring interface %s - it is not allowed", iFace)
return false
}
}
// look for unlisted WireGuard interfaces
wg, err := wgctrl.New()
if err != nil {
log.Debugf("trying to create a wgctrl client failed with: %v", err)
return true
}
defer func() {
_ = wg.Close()
}()
_, err = wg.Device(iFace)
return err != nil
}
}

View File

@@ -0,0 +1,97 @@
// Package stdnet is an extension of the pion's stdnet.
// With it the list of the interface can come from external source.
// More info: https://github.com/golang/go/issues/40569
package stdnet
import (
"fmt"
"github.com/pion/transport/v2"
"github.com/pion/transport/v2/stdnet"
)
// Net is an implementation of the net.Net interface
// based on functions of the standard net package.
type Net struct {
stdnet.Net
interfaces []*transport.Interface
iFaceDiscover iFaceDiscover
// interfaceFilter should return true if the given interfaceName is allowed
interfaceFilter func(interfaceName string) bool
}
// NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
n := &Net{
iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover),
interfaceFilter: InterfaceFilter(disallowList),
}
return n, n.UpdateInterfaces()
}
// NewNet creates a new StdNet instance.
func NewNet(disallowList []string) (*Net, error) {
n := &Net{
iFaceDiscover: pionDiscover{},
interfaceFilter: InterfaceFilter(disallowList),
}
return n, n.UpdateInterfaces()
}
// UpdateInterfaces updates the internal list of network interfaces
// and associated addresses filtering them by name.
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
// wasn't specified.
func (n *Net) UpdateInterfaces() (err error) {
allIfaces, err := n.iFaceDiscover.iFaces()
if err != nil {
return err
}
n.interfaces = n.filterInterfaces(allIfaces)
return nil
}
// Interfaces returns a slice of interfaces which are available on the
// system
func (n *Net) Interfaces() ([]*transport.Interface, error) {
return n.interfaces, nil
}
// InterfaceByIndex returns the interface specified by index.
//
// On Solaris, it returns one of the logical network interfaces
// sharing the logical data link; for more precision use
// InterfaceByName.
func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
for _, ifc := range n.interfaces {
if ifc.Index == index {
return ifc, nil
}
}
return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index)
}
// InterfaceByName returns the interface specified by name.
func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
for _, ifc := range n.interfaces {
if ifc.Name == name {
return ifc, nil
}
}
return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
}
func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.Interface {
if n.interfaceFilter == nil {
return interfaces
}
result := []*transport.Interface{}
for _, iface := range interfaces {
if n.interfaceFilter(iface.Name) {
result = append(result, iface)
}
}
return result
}

View File

@@ -6,4 +6,4 @@
#define EXPAND(x) STRINGIZE(x)
CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
7 ICON ui/netbird.ico
wireguard.dll RCDATA wireguard.dll
wintun.dll RCDATA wintun.dll

View File

@@ -78,7 +78,7 @@ func (s *Server) Start() error {
// on failure we return error to retry
config, err := internal.UpdateConfig(s.latestConfigInput)
if errorStatus, ok := gstatus.FromError(err); ok && errorStatus.Code() == codes.NotFound {
config, err = internal.UpdateOrCreateConfig(s.latestConfigInput)
s.config, err = internal.UpdateOrCreateConfig(s.latestConfigInput)
if err != nil {
log.Warnf("unable to create configuration file: %v", err)
return err
@@ -102,7 +102,7 @@ func (s *Server) Start() error {
}
go func() {
if err := internal.RunClient(ctx, config, s.statusRecorder, nil); err != nil {
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil); err != nil {
log.Errorf("init connections: %v", err)
}
}()
@@ -223,12 +223,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
}
hostedClient := internal.NewHostedDeviceFlow(
providerConfig.ProviderConfig.Audience,
providerConfig.ProviderConfig.ClientID,
providerConfig.ProviderConfig.TokenEndpoint,
providerConfig.ProviderConfig.DeviceAuthEndpoint,
)
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
@@ -344,7 +339,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.expiresAt = time.Now()
s.mutex.Unlock()
if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.AccessToken); err != nil {
if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.GetTokenToUse()); err != nil {
state.Set(loginStatus)
return nil, err
}
@@ -394,7 +389,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
}
go func() {
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil); err != nil {
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil); err != nil {
log.Errorf("run client connection: %v", err)
return
}

View File

@@ -1,27 +0,0 @@
#!/bin/bash
ldir=$PWD
tmp_dir_path=$ldir/.distfiles
winnt=wireguard-nt.zip
download_file_path=$tmp_dir_path/$winnt
download_url=https://download.wireguard.com/wireguard-nt/wireguard-nt-0.10.1.zip
download_sha=772c0b1463d8d2212716f43f06f4594d880dea4f735165bd68e388fc41b81605
function resources_windows(){
cmd=$1
arch=$2
out=$3
docker run -i --rm -v $PWD:$PWD -w $PWD mstorsjo/llvm-mingw:latest $cmd -O coff -c 65001 -I $tmp_dir_path/wireguard-nt/bin/$arch -i resources.rc -o $out
}
mkdir -p $tmp_dir_path
curl -L#o $download_file_path.unverified $download_url
echo "$download_sha $download_file_path.unverified" | sha256sum -c
mv $download_file_path.unverified $download_file_path
mkdir -p .deps
unzip $download_file_path -d $tmp_dir_path
resources_windows i686-w64-mingw32-windres x86 resources_windows_386.syso
resources_windows aarch64-w64-mingw32-windres arm64 resources_windows_arm64.syso
resources_windows x86_64-w64-mingw32-windres amd64 resources_windows_amd64.syso

View File

@@ -3,10 +3,13 @@ package encryption
import (
"crypto/rand"
"fmt"
"golang.org/x/crypto/nacl/box"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const nonceSize = 24
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
// Wireguard keys are used for encryption
@@ -26,8 +29,11 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.
if err != nil {
return nil, err
}
copy(nonce[:], encryptedMsg[:24])
opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
if len(encryptedMsg) < nonceSize {
return nil, fmt.Errorf("invalid encrypted message lenght")
}
copy(nonce[:], encryptedMsg[:nonceSize])
opened, ok := box.Open(nil, encryptedMsg[nonceSize:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
if !ok {
return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String())
}
@@ -36,8 +42,8 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.
}
// Generates nonce of size 24
func genNonce() (*[24]byte, error) {
var nonce [24]byte
func genNonce() (*[nonceSize]byte, error) {
var nonce [nonceSize]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, err
}

11
go.mod
View File

@@ -19,7 +19,7 @@ require (
github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.7.0
golang.org/x/sys v0.6.0
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/grpc v1.52.3
@@ -37,6 +37,7 @@ require (
github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0
github.com/google/go-cmp v0.5.9
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
@@ -47,6 +48,8 @@ require (
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/open-policy-agent/opa v0.49.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
github.com/pion/stun v0.4.0
github.com/pion/transport/v2 v2.0.2
github.com/prometheus/client_golang v1.14.0
github.com/rs/xid v1.3.0
@@ -55,6 +58,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
go.opentelemetry.io/otel/metric v0.33.0
go.opentelemetry.io/otel/sdk/metric v0.33.0
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf
golang.org/x/net v0.8.0
golang.org/x/sync v0.1.0
golang.org/x/term v0.6.0
@@ -89,7 +93,6 @@ require (
github.com/go-stack/stack v1.8.0 // indirect
github.com/gobwas/glob v0.2.3 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/gopacket v1.1.19 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -102,10 +105,8 @@ require (
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.6 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/stun v0.4.0 // indirect
github.com/pion/turn/v2 v2.1.0 // indirect
github.com/pion/udp/v2 v2.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -126,12 +127,10 @@ require (
go.opentelemetry.io/otel v1.11.1 // indirect
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/text v0.8.0 // indirect
golang.org/x/tools v0.6.0 // indirect
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d // indirect
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect

5
go.sum
View File

@@ -881,13 +881,12 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 h1:3zl8RkJNQ8wfPRomwv/6DBbH2Ut6dgMaWTxM0ZunWnE=
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI=
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU=
golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno=
golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ=

208
iface/bind/bind.go Normal file
View File

@@ -0,0 +1,208 @@
package bind
import (
"errors"
"fmt"
"github.com/pion/stun"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"net"
"net/netip"
"sync"
"syscall"
)
// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library
type ICEBind struct {
// below fields, initialized on open
sharedConn net.PacketConn
udpMux *UniversalUDPMuxDefault
// below are fields initialized on creation
transportNet transport.Net
mu sync.Mutex
}
// NewICEBind create a new instance of ICEBind with a given transportNet and an interfaceFilter function.
// The interfaceFilter function is used to exclude interfaces from hole punching (the IPs of that interfaces won't be used as connection candidates)
// The transportNet can be nil.
func NewICEBind(transportNet transport.Net) *ICEBind {
return &ICEBind{
transportNet: transportNet,
mu: sync.Mutex{},
}
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.udpMux, nil
}
// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.sharedConn != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
ipv4Conn, _, err := listenNet("udp4", int(uport))
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.sharedConn = ipv4Conn
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn, Net: b.transportNet})
portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
if err != nil {
return nil, 0, err
}
log.Infof("opened ICEBind on %s", ipv4Conn.LocalAddr().String())
return []conn.ReceiveFunc{
b.makeReceiveIPv4(b.sharedConn),
},
portAddr1.Port(), nil
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: append([]byte{}, raw...),
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff[:20]) {
// WireGuard traffic
return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil
}
msg, err := parseSTUNMessage(buff[:n])
if err != nil {
return 0, nil, err
}
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
}
}
// Close closes the WireGuard socket and UDPMux
func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
var err1, err2 error
if b.sharedConn != nil {
c := b.sharedConn
b.sharedConn = nil
err1 = c.Close()
}
if b.udpMux != nil {
m := b.udpMux
b.udpMux = nil
err2 = m.Close()
}
if err1 != nil {
return err1
}
return err2
}
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
func (b *ICEBind) SetMark(mark uint32) error {
return nil
}
// Send bytes to the remote endpoint (peer)
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
nend, ok := endpoint.(conn.StdNetEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
addrPort := netip.AddrPort(nend)
_, err := b.sharedConn.WriteTo(buff, &net.UDPAddr{
IP: addrPort.Addr().AsSlice(),
Port: int(addrPort.Port()),
Zone: addrPort.Addr().Zone(),
})
return err
}
// ParseEndpoint creates a new endpoint from a string.
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
e, err := netip.ParseAddrPort(s)
return asEndpoint(e), err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]conn.Endpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) conn.Endpoint {
m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
e = conn.Endpoint(conn.StdNetEndpoint(ap))
m[ap] = e
}
return e
}

444
iface/bind/udp_mux.go Normal file
View File

@@ -0,0 +1,444 @@
package bind
import (
"fmt"
"github.com/pion/ice/v2"
"github.com/pion/stun"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"io"
"net"
"strings"
"sync"
"github.com/pion/logging"
"github.com/pion/transport/v2"
)
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
mu sync.Mutex
// for UDP connection listen at unspecified address
localAddrsForUnspecified []net.Addr
}
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
// Required for gathering local addresses
// in case a un UDPConn is passed which does not
// bind to a specific local address.
Net transport.Net
InterfaceFilter func(interfaceName string) bool
}
func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []ice.NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
ips := []net.IP{}
ifaces, err := n.Interfaces()
if err != nil {
return ips, err
}
var IPv4Requested, IPv6Requested bool
for _, typ := range networkTypes {
if typ.IsIPv4() {
IPv4Requested = true
}
if typ.IsIPv6() {
IPv6Requested = true
}
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
continue // loopback interface
}
if interfaceFilter != nil && !interfaceFilter(iface.Name) {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch addr := addr.(type) {
case *net.IPNet:
ip = addr.IP
case *net.IPAddr:
ip = addr.IP
}
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
continue
}
if ipv4 := ip.To4(); ipv4 == nil {
if !IPv6Requested {
continue
} else if !isSupportedIPv6(ip) {
continue
}
} else if !IPv4Requested {
continue
}
if ipFilter != nil && !ipFilter(ip) {
continue
}
ips = append(ips, ip)
}
}
return ips, nil
}
// The conditions of invalidation written below are defined in
// https://tools.ietf.org/html/rfc8445#section-5.1.1.1
func isSupportedIPv6(ip net.IP) bool {
if len(ip) != net.IPv6len ||
isZeros(ip[0:12]) || // !(IPv4-compatible IPv6)
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast)
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() {
return false
}
return true
}
func isZeros(ip net.IP) bool {
for i := 0; i < len(ip); i++ {
if ip[i] != 0 {
return false
}
}
return true
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
var localAddrsForUnspecified []net.Addr
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
localAddrsForUnspecified: localAddrsForUnspecified,
}
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified
}
return []net.Addr{m.LocalAddr()}
}
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
m.mu.Lock()
defer m.mu.Unlock()
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
}
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
}
m.mu.Unlock()
if len(removedConns) == 0 {
// No need to lock if no connection was found
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != ufrag {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
default:
return false
}
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, c := range m.connsIPv4 {
_ = c.Close()
}
for _, c := range m.connsIPv6 {
_ = c.Close()
}
m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)
close(m.closedChan)
_ = m.params.UDPConn.Close()
})
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.Unlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}
type bufferHolder struct {
buf []byte
}
func newBufferHolder(size int) *bufferHolder {
return &bufferHolder{
buf: make([]byte, size),
}
}

View File

@@ -0,0 +1,253 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
*/
import (
"fmt"
log "github.com/sirupsen/logrus"
"net"
"time"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/v2"
)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped
}
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25
}
m := &UniversalUDPMuxDefault{
params: params,
xorMappedMap: make(map[string]*xorMapped),
}
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
}
// GetListenAddresses returns the listen addr of this UDP
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}
}
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
// Not implemented yet.
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
return nil, fmt.Errorf("not implemented yet")
}
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
// All other STUN packets will be forwarded to the UDPMux
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {
log.Debugf("%s: %v", fmt.Errorf("failed to get XOR-MAPPED-ADDRESS response"), err)
return nil
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock()
defer m.mu.Unlock()
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
_, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok
}
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
// and set the mapped address for the server
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock()
defer m.mu.Unlock()
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
if !ok {
return fmt.Errorf("no XOR address mapping")
}
var addr stun.XORMappedAddress
if err := addr.GetFrom(msg); err != nil {
return err
}
m.xorMappedMap[stunAddr.String()] = mappedAddr
mappedAddr.SetAddr(&addr)
return nil
}
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
// Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// if we already have a mapping for this STUN server (address already received)
// and if it is not too old we return it without making a new request to STUN server
if ok {
if mappedAddr.expired() {
mappedAddr.closeWaiters()
delete(m.xorMappedMap, serverAddr.String())
ok = false
} else if mappedAddr.pending() {
ok = false
}
}
m.mu.Unlock()
if ok {
return mappedAddr.addr, nil
}
// otherwise, make a STUN request to discover the address
// or wait for already sent request to complete
waitAddrReceived, err := m.sendStun(serverAddr)
if err != nil {
return nil, fmt.Errorf("%s: %s", "failed to send STUN packet", err)
}
// block until response was handled by the connWorker routine and XORMappedAddress was updated
select {
case <-waitAddrReceived:
// when channel closed, addr was obtained
m.mu.Lock()
mappedAddr := *m.xorMappedMap[serverAddr.String()]
m.mu.Unlock()
if mappedAddr.addr == nil {
return nil, fmt.Errorf("no XOR address mapping")
}
return mappedAddr.addr, nil
case <-time.After(deadline):
return nil, fmt.Errorf("timeout while waiting for XORMappedAddr")
}
}
// sendStun sends a STUN request via UDP conn.
//
// The returned channel is closed when the STUN response has been received.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) {
m.mu.Lock()
defer m.mu.Unlock()
// if record present in the map, we already sent a STUN request,
// just wait when waitAddrReceived will be closed
addrMap, ok := m.xorMappedMap[serverAddr.String()]
if !ok {
addrMap = &xorMapped{
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
waitAddrReceived: make(chan struct{}),
}
m.xorMappedMap[serverAddr.String()] = addrMap
}
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
if err != nil {
return nil, err
}
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
return nil, err
}
return addrMap.waitAddrReceived, nil
}
type xorMapped struct {
addr *stun.XORMappedAddress
waitAddrReceived chan struct{}
expiresAt time.Time
}
func (a *xorMapped) closeWaiters() {
select {
case <-a.waitAddrReceived:
// notify was close, ok, that means we received duplicate response
// just exit
break
default:
// notify tha twe have a new addr
close(a.waitAddrReceived)
}
}
func (a *xorMapped) pending() bool {
return a.addr == nil
}
func (a *xorMapped) expired() bool {
return a.expiresAt.Before(time.Now())
}
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
a.addr = addr
a.closeWaiters()
}

View File

@@ -0,0 +1,233 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
import (
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/transport/v2/packetio"
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
buf *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
}
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
buf: packetio.NewBuffer(),
closedChan: make(chan struct{}),
}
return p
}
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// read address
total, err := c.buf.Read(buf.buf)
if err != nil {
return 0, nil, err
}
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
}
// read data and then address
offset := 2
copy(b, buf.buf[offset:offset+dataLen])
offset += dataLen
// read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
offset += 2
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
return 0, nil, err
}
return dataLen, rAddr, nil
}
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
}
// each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
}
return c.params.Mux.writeTo(buf, rAddr)
}
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
}
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
}
func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buf.Close()
close(c.closedChan)
})
return err
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
}
func (c *udpMuxedConn) getAddresses() []string {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
func (c *udpMuxedConn) addAddress(addr string) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
}
func (c *udpMuxedConn) containsAddress(addr string) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
}
}
return false
}
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buf) < len(data)+maxAddrSize {
return io.ErrShortBuffer
}
// data len
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
offset := 2
// data
copy(buf.buf[offset:], data)
offset += len(data)
// write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
if err != nil {
return err
}
total := offset + n + 2
// address len
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
return err
}
return nil
}
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
ipData, err := addr.IP.MarshalText()
if err != nil {
return 0, err
}
total := 2 + len(ipData) + 2 + len(addr.Zone)
if total > len(buf) {
return 0, io.ErrShortBuffer
}
binary.LittleEndian.PutUint16(buf, uint16(len(ipData)))
offset := 2
n := copy(buf[offset:], ipData)
offset += n
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2
copy(buf[offset:], addr.Zone)
return total, nil
}
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := net.UDPAddr{}
offset := 0
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
offset += 2
// basic bounds checking
if ipLen+offset > len(buf) {
return nil, io.ErrShortBuffer
}
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
return nil, err
}
offset += ipLen
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
offset += 2
zone := make([]byte, len(buf[offset:]))
copy(zone, buf[offset:])
addr.Zone = string(zone)
return &addr, nil
}

View File

@@ -5,20 +5,34 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/iface/bind"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const (
DefaultMTU = 1280
DefaultWgPort = 51820
defaultWgKeepAlive = 25 * time.Second
)
// WGIface represents a interface instance
type WGIface struct {
tun *tunDevice
configurer wGConfigurer
mu sync.Mutex
tun *tunDevice
configurer wGConfigurer
mu sync.Mutex
userspaceBind bool
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind
}
// GetBind returns a userspace implementation of WireGuard Bind interface
func (w *WGIface) GetBind() *bind.ICEBind {
return w.tun.iceBind
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
@@ -26,7 +40,7 @@ type WGIface struct {
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("create Wireguard interface %s", w.tun.DeviceName())
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create()
}
@@ -64,12 +78,12 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, endpoint *net.UDPAddr) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s: endpoint %s ", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
return w.configurer.updatePeer(peerKey, allowedIps, endpoint)
}
// RemovePeer removes a Wireguard Peer from the interface iface

View File

@@ -1,22 +1,28 @@
package iface
import "sync"
import (
"sync"
// NewWGIFace Creates a new Wireguard interface instance
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) {
wgIface := &WGIface{
"github.com/pion/transport/v2"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, preSharedKey *wgtypes.Key, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
mu: sync.Mutex{},
}
wgAddress, err := parseWGAddress(address)
if err != nil {
return wgIface, err
return wgIFace, err
}
tun := newTunDevice(wgAddress, mtu, tunAdapter)
wgIface.tun = tun
tun := newTunDevice(wgAddress, mtu, tunAdapter, transportNet)
wgIFace.tun = tun
wgIface.configurer = newWGConfigurer(tun)
wgIFace.configurer = newWGConfigurer(tun, preSharedKey)
return wgIface, nil
wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
}

View File

@@ -2,21 +2,27 @@
package iface
import "sync"
import (
"github.com/pion/transport/v2"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
// NewWGIFace Creates a new Wireguard interface instance
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter) (*WGIface, error) {
wgIface := &WGIface{
"sync"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, preSharedKey *wgtypes.Key, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
mu: sync.Mutex{},
}
wgAddress, err := parseWGAddress(address)
if err != nil {
return wgIface, err
return wgIFace, err
}
wgIface.tun = newTunDevice(ifaceName, wgAddress, mtu)
wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet)
wgIface.configurer = newWGConfigurer(ifaceName)
return wgIface, nil
wgIFace.configurer = newWGConfigurer(iFaceName, preSharedKey)
wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
}

View File

@@ -2,6 +2,7 @@ package iface
import (
"fmt"
"github.com/pion/transport/v2/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
@@ -32,7 +33,12 @@ func init() {
func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8"
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
@@ -92,7 +98,11 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
@@ -121,7 +131,11 @@ func Test_CreateInterface(t *testing.T) {
func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
@@ -149,7 +163,11 @@ func Test_Close(t *testing.T) {
func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
@@ -196,7 +214,11 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
@@ -255,7 +277,11 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30"
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
@@ -304,8 +330,11 @@ func Test_ConnectPeers(t *testing.T) {
peer2Key, _ := wgtypes.GeneratePrivateKey()
keepAlive := 1 * time.Second
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}
@@ -322,7 +351,11 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil)
newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, newNet)
if err != nil {
t.Fatal(err)
}

View File

@@ -33,7 +33,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", preSharedHexKey))
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
}
if p.Remove {

View File

@@ -3,7 +3,7 @@
package iface
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
return false
}

View File

@@ -33,6 +33,7 @@ const (
loading
live
inuse
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
)
type module struct {
@@ -81,9 +82,15 @@ func tunModuleIsLoaded() bool {
return tunLoaded
}
// WireguardModuleIsLoaded check if we can load wireguard mod (linux only)
func WireguardModuleIsLoaded() bool {
if canCreateFakeWireguardInterface() {
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if canCreateFakeWireGuardInterface() {
return true
}
@@ -96,7 +103,7 @@ func WireguardModuleIsLoaded() bool {
return loaded
}
func canCreateFakeWireguardInterface() bool {
func canCreateFakeWireGuardInterface() bool {
link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid

View File

@@ -1,11 +1,12 @@
package iface
import (
"github.com/netbirdio/netbird/iface/bind"
"github.com/pion/transport/v2"
"net"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
@@ -16,17 +17,19 @@ type tunDevice struct {
mtu int
tunAdapter TunAdapter
fd int
name string
device *device.Device
uapi net.Listener
fd int
name string
device *device.Device
uapi net.Listener
iceBind *bind.ICEBind
}
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter) *tunDevice {
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
return &tunDevice{
address: address,
mtu: mtu,
tunAdapter: tunAdapter,
iceBind: bind.NewICEBind(transportNet),
}
}
@@ -46,7 +49,7 @@ func (t *tunDevice) Create() error {
t.name = name
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(tunDevice, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device.DisableSomeRoamingForBrokenMobileSemantics()
log.Debugf("create uapi")

View File

@@ -11,7 +11,7 @@ import (
)
func (c *tunDevice) Create() error {
if WireguardModuleIsLoaded() {
if WireGuardModuleIsLoaded() {
log.Info("using kernel WireGuard")
return c.createWithKernel()
}
@@ -30,7 +30,7 @@ func (c *tunDevice) Create() error {
}
// createWithKernel Creates a new Wireguard interface using kernel Wireguard module.
// createWithKernel Creates a new WireGuard interface using kernel WireGuard module.
// Works for Linux and offers much better network performance
func (c *tunDevice) createWithKernel() error {

View File

@@ -3,11 +3,12 @@
package iface
import (
"github.com/netbirdio/netbird/iface/bind"
"github.com/pion/transport/v2"
"net"
"os"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
@@ -18,13 +19,15 @@ type tunDevice struct {
address WGAddress
mtu int
netInterface NetInterface
iceBind *bind.ICEBind
}
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice {
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
return &tunDevice{
name: name,
address: address,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
}
}
@@ -69,7 +72,7 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
}
// We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
tunDevice := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
err = tunDevice.Up()
if err != nil {
return tunIface, err

View File

@@ -2,26 +2,32 @@ package iface
import (
"fmt"
"net"
"github.com/netbirdio/netbird/iface/bind"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"net"
)
type tunDevice struct {
name string
address WGAddress
netInterface NetInterface
uapi net.Listener
iceBind *bind.ICEBind
mtu int
}
func newTunDevice(name string, address WGAddress, mtu int) *tunDevice {
return &tunDevice{name: name, address: address}
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
return &tunDevice{name: name, address: address, iceBind: bind.NewICEBind(transportNet), mtu: mtu}
}
func (c *tunDevice) Create() error {
var err error
c.netInterface, err = c.createAdapter()
c.netInterface, err = c.createWithUserspace()
if err != nil {
return err
}
@@ -29,6 +35,41 @@ func (c *tunDevice) Create() error {
return c.assignAddr()
}
// createWithUserspace Creates a new WireGuard interface, using wireguard-go userspace implementation
func (c *tunDevice) createWithUserspace() (NetInterface, error) {
tunIface, err := tun.CreateTUN(c.name, c.mtu)
if err != nil {
return nil, err
}
// We need to create a wireguard-go device and listen to configuration requests
tunDevice := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDevice.Up()
if err != nil {
return tunIface, err
}
uapi, err := c.getUAPI(c.name)
if err != nil {
return nil, err
}
c.uapi = uapi
go func() {
for {
uapiConn, uapiErr := uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi accept failed with error: ", uapiErr)
continue
}
go tunDevice.IpcHandle(uapiConn)
}
}()
log.Debugln("UAPI listener started")
return tunIface, nil
}
func (c *tunDevice) UpdateAddr(address WGAddress) error {
c.address = address
return c.assignAddr()
@@ -43,19 +84,28 @@ func (c *tunDevice) DeviceName() string {
}
func (c *tunDevice) Close() error {
if c.netInterface == nil {
return nil
var err1, err2 error
if c.netInterface != nil {
err1 = c.netInterface.Close()
}
return c.netInterface.Close()
if c.uapi != nil {
err2 = c.uapi.Close()
}
if err1 != nil {
return err1
}
return err2
}
func (c *tunDevice) getInterfaceGUIDString() (string, error) {
if c.netInterface == nil {
return "", fmt.Errorf("interface has not been initialized yet")
}
windowsDevice := c.netInterface.(*driver.Adapter)
luid := windowsDevice.LUID()
windowsDevice := c.netInterface.(*tun.NativeTun)
luid := winipcfg.LUID(windowsDevice.LUID())
guid, err := luid.GUID()
if err != nil {
return "", err
@@ -63,31 +113,15 @@ func (c *tunDevice) getInterfaceGUIDString() (string, error) {
return guid.String(), nil
}
func (c *tunDevice) createAdapter() (NetInterface, error) {
WintunStaticRequestedGUID, _ := windows.GenerateGUID()
adapter, err := driver.CreateAdapter(c.name, "WireGuard", &WintunStaticRequestedGUID)
if err != nil {
err = fmt.Errorf("error creating adapter: %w", err)
return nil, err
}
err = adapter.SetAdapterState(driver.AdapterStateUp)
if err != nil {
return adapter, err
}
state, _ := adapter.LUID().GUID()
log.Debugln("device guid: ", state.String())
return adapter, nil
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (c *tunDevice) assignAddr() error {
luid := c.netInterface.(*driver.Adapter).LUID()
tunDev := c.netInterface.(*tun.NativeTun)
luid := winipcfg.LUID(tunDev.LUID())
log.Debugf("adding address %s to interface: %s", c.address.IP, c.name)
err := luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
if err != nil {
return err
}
return nil
return luid.SetIPAddresses([]net.IPNet{{c.address.IP, c.address.Network.Mask}})
}
// getUAPI returns a Listener
func (c *tunDevice) getUAPI(iface string) (net.Listener, error) {
return ipc.UAPIListen(iface)
}

View File

@@ -3,7 +3,6 @@ package iface
import (
"errors"
"net"
"time"
log "github.com/sirupsen/logrus"
@@ -15,12 +14,14 @@ var (
)
type wGConfigurer struct {
tunDevice *tunDevice
tunDevice *tunDevice
preSharedKey *wgtypes.Key
}
func newWGConfigurer(tunDevice *tunDevice) wGConfigurer {
func newWGConfigurer(tunDevice *tunDevice, preSharedKey *wgtypes.Key) wGConfigurer {
return wGConfigurer{
tunDevice: tunDevice,
tunDevice: tunDevice,
preSharedKey: preSharedKey,
}
}
@@ -41,7 +42,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
return c.tunDevice.Device().IpcSet(toWgUserspaceString(config))
}
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, endpoint *net.UDPAddr) error {
//parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
@@ -52,12 +53,13 @@ func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive t
if err != nil {
return err
}
keepalive := defaultWgKeepAlive
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
PresharedKey: c.preSharedKey,
Endpoint: endpoint,
}

View File

@@ -5,7 +5,6 @@ package iface
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
@@ -13,12 +12,14 @@ import (
)
type wGConfigurer struct {
deviceName string
deviceName string
preSharedKey *wgtypes.Key
}
func newWGConfigurer(deviceName string) wGConfigurer {
func newWGConfigurer(deviceName string, preSharedKey *wgtypes.Key) wGConfigurer {
return wGConfigurer{
deviceName: deviceName,
deviceName: deviceName,
preSharedKey: preSharedKey,
}
}
@@ -43,7 +44,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
return nil
}
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, endpoint *net.UDPAddr) error {
//parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
@@ -54,12 +55,13 @@ func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive t
if err != nil {
return err
}
keepalive := defaultWgKeepAlive
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
PersistentKeepaliveInterval: &keepalive,
PresharedKey: c.preSharedKey,
Endpoint: endpoint,
}

View File

@@ -3,26 +3,31 @@
# Management API
# Management API port
NETBIRD_MGMT_API_PORT=33073
NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073}
# Management API endpoint address, used by the Dashboard
NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
# Management Certficate file path. These are generated by the Dashboard container
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem"
NETBIRD_LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem"
# Management Certficate key file path.
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem"
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem"
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
# Turn credentials
# Signal
NETBIRD_SIGNAL_PROTOCOL="http"
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000}
# Turn credentials
# User
TURN_USER=self
# Password. If empty, the configure.sh will generate one with openssl
TURN_PASSWORD=
# Min port
TURN_MIN_PORT=49152
TURN_MIN_PORT=${TURN_MIN_PORT:-49152}
# Max port
TURN_MAX_PORT=65535
TURN_MAX_PORT=${TURN_MAX_PORT:-65535}
VOLUME_PREFIX="netbird-"
MGMT_VOLUMESUFFIX="mgmt"
@@ -30,8 +35,13 @@ SIGNAL_VOLUMESUFFIX="signal"
LETSENCRYPT_VOLUMESUFFIX="letsencrypt"
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=${NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE:-$NETBIRD_AUTH_AUDIENCE}
NETBIRD_AUTH_DEVICE_AUTH_SCOPE=${NETBIRD_AUTH_DEVICE_AUTH_SCOPE:-openid}
NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=${NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN:-false}
NETBIRD_DISABLE_ANONYMOUS_METRICS=${NETBIRD_DISABLE_ANONYMOUS_METRICS:-false}
NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
# exports
export NETBIRD_DOMAIN
@@ -44,6 +54,7 @@ export NETBIRD_AUTH_JWT_CERTS
export NETBIRD_LETSENCRYPT_EMAIL
export NETBIRD_MGMT_API_PORT
export NETBIRD_MGMT_API_ENDPOINT
export NETBIRD_LETSENCRYPT_DOMAIN
export NETBIRD_MGMT_API_CERT_FILE
export NETBIRD_MGMT_API_CERT_KEY_FILE
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER
@@ -61,4 +72,11 @@ export SIGNAL_VOLUMESUFFIX
export LETSENCRYPT_VOLUMESUFFIX
export NETBIRD_DISABLE_ANONYMOUS_METRICS
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN
export NETBIRD_MGMT_DNS_DOMAIN
export NETBIRD_MGMT_DNS_DOMAIN
export NETBIRD_SIGNAL_PROTOCOL
export NETBIRD_SIGNAL_PORT
export NETBIRD_AUTH_USER_ID_CLAIM
export NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
export NETBIRD_TOKEN_SOURCE
export NETBIRD_AUTH_DEVICE_AUTH_SCOPE
export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN

View File

@@ -121,6 +121,30 @@ if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
fi
# Check if letsencrypt was disabled
if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]
then
export NETBIRD_DASHBOARD_ENDPOINT="https://$NETBIRD_DOMAIN:443"
export NETBIRD_SIGNAL_ENDPOINT="https://$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT"
echo "Letsencrypt was disabled, the Https-endpoints cannot be used anymore"
echo " and a reverse-proxy with Https needs to be placed in front of netbird!"
echo "The following forwards have to be setup:"
echo "- $NETBIRD_DASHBOARD_ENDPOINT -http-> dashboard:80"
echo "- $NETBIRD_MGMT_API_ENDPOINT/api -http-> management:$NETBIRD_MGMT_API_PORT"
echo "- $NETBIRD_MGMT_API_ENDPOINT/management.ManagementService/ -grpc-> management:$NETBIRD_MGMT_API_PORT"
echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80"
echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script."
echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!"
echo "You are also free to remove any occurences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
echo ""
export NETBIRD_SIGNAL_PROTOCOL="https"
unset NETBIRD_LETSENCRYPT_DOMAIN
unset NETBIRD_MGMT_API_CERT_FILE
unset NETBIRD_MGMT_API_CERT_KEY_FILE
fi
env | grep NETBIRD
envsubst < docker-compose.yml.tmpl > docker-compose.yml

View File

@@ -8,20 +8,26 @@ services:
- 80:80
- 443:443
environment:
# Endpoints
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
# OIDC
- AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
- USE_AUTH0=$NETBIRD_USE_AUTH0
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
- NGINX_SSL_PORT=443
- LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
- NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE
# SSL
- NGINX_SSL_PORT=443
# Letsencrypt
- LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
volumes:
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
# Signal
signal:
image: netbirdio/signal:latest
@@ -32,7 +38,8 @@ services:
- 10000:80
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
# Management
management:
image: netbirdio/management:latest
@@ -46,8 +53,15 @@ services:
ports:
- $NETBIRD_MGMT_API_PORT:443 #API port
# # command for Let's Encrypt validation without dashboard container
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"]
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
command: [
"--port", "443",
"--log-file", "console",
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
]
# Coturn
coturn:
image: coturn/coturn
@@ -60,6 +74,7 @@ services:
network_mode: host
command:
- -c /etc/turnserver.conf
volumes:
$MGMT_VOLUMENAME:
$SIGNAL_VOLUMENAME:

View File

@@ -0,0 +1,99 @@
version: "3"
services:
#UI dashboard
dashboard:
image: wiretrustee/dashboard:latest
restart: unless-stopped
#ports:
# - 80:80
# - 443:443
environment:
# Endpoints
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
# OIDC
- AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
- USE_AUTH0=$NETBIRD_USE_AUTH0
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
# SSL
- NGINX_SSL_PORT=443
# Letsencrypt
- LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
volumes:
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
labels:
- traefik.enable=true
- traefik.http.routers.netbird-dashboard.rule=Host(`$NETBIRD_DOMAIN`)
- traefik.http.services.netbird-dashboard.loadbalancer.server.port=80
# Signal
signal:
image: netbirdio/signal:latest
restart: unless-stopped
volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird
#ports:
# - 10000:80
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
labels:
- traefik.enable=true
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.services.netbird-signal.loadbalancer.server.port=80
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
# Management
management:
image: netbirdio/management:latest
restart: unless-stopped
depends_on:
- dashboard
volumes:
- $MGMT_VOLUMENAME:/var/lib/netbird
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
- ./management.json:/etc/netbird/management.json
#ports:
# - $NETBIRD_MGMT_API_PORT:443 #API port
# # command for Let's Encrypt validation without dashboard container
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
command: [
"--port", "443",
"--log-file", "console",
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
]
labels:
- traefik.enable=true
- traefik.http.routers.netbird-api.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/api`)
- traefik.http.routers.netbird-api.service=netbird-api
- traefik.http.services.netbird-api.loadbalancer.server.port=443
- traefik.http.routers.netbird-management.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/management.ManagementService/`)
- traefik.http.routers.netbird-management.service=netbird-management
- traefik.http.services.netbird-management.loadbalancer.server.port=443
- traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c
# Coturn
coturn:
image: coturn/coturn
restart: unless-stopped
domainname: $NETBIRD_DOMAIN
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
# - ./cert.pem:/etc/coturn/certs/cert.pem:ro
network_mode: host
command:
- -c /etc/turnserver.conf
volumes:
$MGMT_VOLUMENAME:
$SIGNAL_VOLUMENAME:
$LETSENCRYPT_VOLUMENAME:

View File

@@ -21,8 +21,8 @@
"TimeBasedCredentials": false
},
"Signal": {
"Proto": "http",
"URI": "$NETBIRD_DOMAIN:10000",
"Proto": "$NETBIRD_SIGNAL_PROTOCOL",
"URI": "$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT",
"Username": "",
"Password": null
},
@@ -43,11 +43,13 @@
"DeviceAuthorizationFlow": {
"Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER",
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_AUDIENCE",
"Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE",
"Domain": "$NETBIRD_AUTH0_DOMAIN",
"ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID",
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT"
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
"Scope": "$NETBIRD_AUTH_DEVICE_AUTH_SCOPE",
"UseIDToken": $NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
}
}
}

View File

@@ -2,7 +2,11 @@
##
# Dashboard domain. e.g. app.mydomain.com
NETBIRD_DOMAIN=""
# OIDC configuration e.g., https://example.eu.auth0.com/.well-known/openid-configuration
# -------------------------------------------
# OIDC
# e.g., https://example.eu.auth0.com/.well-known/openid-configuration
# -------------------------------------------
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
NETBIRD_AUTH_AUDIENCE=""
# e.g. netbird-client
@@ -13,14 +17,30 @@ NETBIRD_AUTH_CLIENT_ID=""
NETBIRD_USE_AUTH0="false"
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID=""
# e.g. hello@mydomain.com
NETBIRD_LETSENCRYPT_EMAIL=""
# Some IDPs requires different audience, scopes and to use id token for device authorization flow
# you can customize here:
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid"
NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=false
# if your IDP provider doesn't support fragmented URIs, configure custom
# redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain.
# NETBIRD_AUTH_REDIRECT_URI="/peers"
# NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers"
# Updates the preference to use id tokens instead of access token on dashboard
# Okta and Gitlab IDPs can benefit from this
# NETBIRD_TOKEN_SOURCE="idToken"
# -------------------------------------------
# Letsencrypt
# -------------------------------------------
# Disable letsencrypt
# if disabled, cannot use HTTPS anymore and requires setting up a reverse-proxy to do it instead
NETBIRD_DISABLE_LETSENCRYPT=false
# e.g. hello@mydomain.com
NETBIRD_LETSENCRYPT_EMAIL=""
# Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted

View File

@@ -11,4 +11,9 @@ NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
NETBIRD_AUTH_AUDIENCE=$CI_NETBIRD_AUTH_AUDIENCE
# e.g. hello@mydomain.com
NETBIRD_LETSENCRYPT_EMAIL=""
NETBIRD_AUTH_REDIRECT_URI="/peers"
NETBIRD_AUTH_REDIRECT_URI="/peers"
NETBIRD_DISABLE_LETSENCRYPT=true
NETBIRD_TOKEN_SOURCE="idToken"
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE="super"
NETBIRD_AUTH_USER_ID_CLAIM="email"
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email"

View File

@@ -20,6 +20,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
@@ -144,15 +145,19 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
// blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
if err != nil {
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
s, _ := gstatus.FromError(err)
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
backOff.Reset() // reset backoff counter after successful connection
c.notifyDisconnected()
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay
backOff.Reset()
c.notifyDisconnected()
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
return nil

View File

@@ -19,25 +19,28 @@ import (
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
@@ -179,13 +182,22 @@ var (
tlsEnabled = true
}
jwtValidator, err := jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
)
if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err)
}
httpAPIAuthCfg := httpapi.AuthCfg{
Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg)
httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
@@ -405,6 +417,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
u.Host, config.DeviceAuthorizationFlow.ProviderConfig.Domain)
config.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if config.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
config.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope
}
}
}

View File

@@ -1,15 +1,15 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.12.4
// protoc v3.21.9
// source: management.proto
package proto
import (
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
)
@@ -611,7 +611,7 @@ type ServerKeyResponse struct {
// Server's Wireguard public key
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
// Key expiration timestamp after which the key should be fetched again by the client
ExpiresAt *timestamp.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"`
ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"`
// Version of the Wiretrustee Management Service protocol
Version int32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"`
}
@@ -655,7 +655,7 @@ func (x *ServerKeyResponse) GetKey() string {
return ""
}
func (x *ServerKeyResponse) GetExpiresAt() *timestamp.Timestamp {
func (x *ServerKeyResponse) GetExpiresAt() *timestamppb.Timestamp {
if x != nil {
return x.ExpiresAt
}
@@ -1331,6 +1331,10 @@ type ProviderConfig struct {
DeviceAuthEndpoint string `protobuf:"bytes,5,opt,name=DeviceAuthEndpoint,proto3" json:"DeviceAuthEndpoint,omitempty"`
// TokenEndpoint is an endpoint to request auth token.
TokenEndpoint string `protobuf:"bytes,6,opt,name=TokenEndpoint,proto3" json:"TokenEndpoint,omitempty"`
// Scopes provides the scopes to be included in the token request
Scope string `protobuf:"bytes,7,opt,name=Scope,proto3" json:"Scope,omitempty"`
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool `protobuf:"varint,8,opt,name=UseIDToken,proto3" json:"UseIDToken,omitempty"`
}
func (x *ProviderConfig) Reset() {
@@ -1407,6 +1411,20 @@ func (x *ProviderConfig) GetTokenEndpoint() string {
return ""
}
func (x *ProviderConfig) GetScope() string {
if x != nil {
return x.Scope
}
return ""
}
func (x *ProviderConfig) GetUseIDToken() bool {
if x != nil {
return x.UseIDToken
}
return false
}
// Route represents a route.Route object
type Route struct {
state protoimpl.MessageState
@@ -2000,7 +2018,7 @@ var file_management_proto_rawDesc = []byte{
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76,
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72,
0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44,
0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43,
0x10, 0x00, 0x22, 0x90, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49,
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49,
0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65,
@@ -2013,81 +2031,84 @@ var file_management_proto_rawDesc = []byte{
0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74,
0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b,
0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09,
0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22,
0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74,
0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77,
0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79,
0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20,
0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74,
0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69,
0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18,
0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64,
0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e,
0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f,
0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72,
0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f,
0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e,
0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58,
0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06,
0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18,
0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52,
0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70,
0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04,
0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65,
0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52,
0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20,
0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74,
0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0x7f,
0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75,
0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73,
0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b,
0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50,
0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72,
0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73,
0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22,
0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a,
0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a,
0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e,
0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20,
0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32, 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12,
0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73,
0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c,
0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12,
0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f,
0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44,
0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12,
0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12,
0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74,
0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b,
0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50,
0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12,
0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52,
0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75,
0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73,
0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44,
0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01,
0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01,
0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c,
0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47,
0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65,
0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75,
0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32,
0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73,
0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a,
0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f,
0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65,
0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52,
0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74,
0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12,
0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61,
0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03,
0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18,
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03,
0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14,
0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52,
0x44, 0x61, 0x74, 0x61, 0x22, 0x7f, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76,
0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d,
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65,
0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01,
0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44,
0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20,
0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50,
0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32,
0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72,
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d,
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42,
0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11,
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74,
0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12,
0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70,
0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65,
0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f,
0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x33,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04,
0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61,
0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65,
0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a,
0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72,
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61,
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -2132,7 +2153,7 @@ var file_management_proto_goTypes = []interface{}{
(*SimpleRecord)(nil), // 24: management.SimpleRecord
(*NameServerGroup)(nil), // 25: management.NameServerGroup
(*NameServer)(nil), // 26: management.NameServer
(*timestamp.Timestamp)(nil), // 27: google.protobuf.Timestamp
(*timestamppb.Timestamp)(nil), // 27: google.protobuf.Timestamp
}
var file_management_proto_depIdxs = []int32{
11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig

View File

@@ -246,6 +246,10 @@ message ProviderConfig {
string DeviceAuthEndpoint = 5;
// TokenEndpoint is an endpoint to request auth token.
string TokenEndpoint = 6;
// Scopes provides the scopes to be included in the token request
string Scope = 7;
// UseIDToken indicates if the id token should be used for authentication
bool UseIDToken = 8;
}
// Route represents a route.Route object

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/crc32"
"math/rand"
@@ -54,7 +55,8 @@ type AccountManager interface {
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error)
@@ -66,8 +68,10 @@ type AccountManager interface {
GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error)
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error
DeletePAT(accountID string, userID string, tokenID string) error
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(peerID string, sshKey string) error
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
@@ -1119,45 +1123,85 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
return nil
}
// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
unlock := am.Store.AcquireGlobalLock()
user, err := am.Store.GetUserByTokenID(tokenID)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(user.Id)
if err != nil {
return err
}
unlock()
unlock = am.Store.AcquireAccountLock(account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = time.Now().UTC()
return am.Store.SaveAccount(account)
}
// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, error) {
func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) {
if len(token) != PATLength {
return nil, nil, fmt.Errorf("token has wrong length")
return nil, nil, nil, fmt.Errorf("token has wrong length")
}
prefix := token[:len(PATPrefix)]
if prefix != PATPrefix {
return nil, nil, fmt.Errorf("token has wrong prefix")
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
}
secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return nil, nil, fmt.Errorf("token checksum does not match")
return nil, nil, nil, fmt.Errorf("token checksum does not match")
}
hashedToken := sha256.Sum256([]byte(token))
tokenID, err := am.Store.GetTokenIDByHashedToken(string(hashedToken[:]))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
tokenID, err := am.Store.GetTokenIDByHashedToken(encodedHashedToken)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
user, err := am.Store.GetUserByTokenID(tokenID)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
account, err := am.Store.GetAccountByUser(user.Id)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
return account, user, nil
pat := user.PATs[tokenID]
if pat == nil {
return nil, nil, nil, fmt.Errorf("personal access token not found")
}
return account, user, pat, nil
}
// GetAccountFromToken returns an account associated with this token

View File

@@ -2,6 +2,7 @@ package server
import (
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"net"
"reflect"
@@ -126,12 +127,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
Name: peerID1,
DNSLabel: peerID1,
Status: &PeerStatus{
LastSeen: time.Now(),
LastSeen: time.Now().UTC(),
Connected: false,
LoginExpired: true,
},
UserID: userID,
LastLogin: time.Now().Add(-time.Hour * 24 * 30 * 30),
LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
},
"peer-2": {
ID: peerID2,
@@ -140,12 +141,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
Name: peerID2,
DNSLabel: peerID2,
Status: &PeerStatus{
LastSeen: time.Now(),
LastSeen: time.Now().UTC(),
Connected: false,
LoginExpired: false,
},
UserID: userID,
LastLogin: time.Now(),
LastLogin: time.Now().UTC(),
LoginExpirationEnabled: true,
},
},
@@ -164,12 +165,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
Name: peerID1,
DNSLabel: peerID1,
Status: &PeerStatus{
LastSeen: time.Now(),
LastSeen: time.Now().UTC(),
Connected: false,
LoginExpired: true,
},
UserID: userID,
LastLogin: time.Now().Add(-time.Hour * 24 * 30 * 30),
LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
LoginExpirationEnabled: true,
},
"peer-2": {
@@ -179,12 +180,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
Name: peerID2,
DNSLabel: peerID2,
Status: &PeerStatus{
LastSeen: time.Now(),
LastSeen: time.Now().UTC(),
Connected: false,
LoginExpired: true,
},
UserID: userID,
LastLogin: time.Now().Add(-time.Hour * 24 * 30 * 30),
LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
LoginExpirationEnabled: true,
},
},
@@ -465,12 +466,13 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
"pat1": {
"tokenId": {
ID: "tokenId",
HashedToken: string(hashedToken[:]),
HashedToken: encodedHashedToken,
},
},
}
@@ -483,13 +485,52 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
Store: store,
}
account, user, err := am.GetAccountFromPAT(token)
account, user, pat, err := am.GetAccountFromPAT(token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
assert.Equal(t, "account_id", account.Id)
assert.Equal(t, "someUser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
}
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
LastUsed: time.Time{},
},
},
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
}
err = am.MarkPATUsed("tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount("account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
}
func TestAccountManager_PrivateAccount(t *testing.T) {
@@ -1245,12 +1286,12 @@ func TestAccount_Copy(t *testing.T) {
PATs: map[string]*PersonalAccessToken{
"pat1": {
ID: "pat1",
Description: "First PAT",
Name: "First PAT",
HashedToken: "SoMeHaShEdToKeN",
ExpirationDate: time.Now().AddDate(0, 0, 7),
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
CreatedBy: "user1",
CreatedAt: time.Now(),
LastUsed: time.Now(),
CreatedAt: time.Now().UTC(),
LastUsed: time.Now().UTC(),
},
},
},
@@ -1528,22 +1569,22 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
ID: "peer-1",
LoginExpirationEnabled: true,
Status: &PeerStatus{
LastSeen: time.Now(),
LastSeen: time.Now().UTC(),
Connected: true,
LoginExpired: false,
},
LastLogin: time.Now().Add(-30 * time.Minute),
LastLogin: time.Now().UTC().Add(-30 * time.Minute),
UserID: userID,
},
"peer-2": {
ID: "peer-2",
LoginExpirationEnabled: true,
Status: &PeerStatus{
LastSeen: time.Now(),
LastSeen: time.Now().UTC(),
Connected: true,
LoginExpired: false,
},
LastLogin: time.Now().Add(-2 * time.Hour),
LastLogin: time.Now().UTC().Add(-2 * time.Hour),
UserID: userID,
},
@@ -1551,11 +1592,11 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
ID: "peer-3",
LoginExpirationEnabled: true,
Status: &PeerStatus{
LastSeen: time.Now(),
LastSeen: time.Now().UTC(),
Connected: true,
LoginExpired: false,
},
LastLogin: time.Now().Add(-1 * time.Hour),
LastLogin: time.Now().UTC().Add(-1 * time.Hour),
UserID: userID,
},
},
@@ -1756,7 +1797,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
LoginExpired: false,
},
LoginExpirationEnabled: true,
LastLogin: time.Now(),
LastLogin: time.Now().UTC(),
UserID: userID,
},
"peer-2": {

View File

@@ -83,6 +83,10 @@ const (
AccountPeerLoginExpirationDisabled
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
AccountPeerLoginExpirationDurationUpdated
// PersonalAccessTokenCreated indicates that a user created a personal access token
PersonalAccessTokenCreated
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted
)
const (
@@ -168,6 +172,10 @@ const (
AccountPeerLoginExpirationDisabledMessage string = "Peer login expiration disabled for the account"
// AccountPeerLoginExpirationDurationUpdatedMessage is a human-readable text message of the AccountPeerLoginExpirationDurationUpdated activity
AccountPeerLoginExpirationDurationUpdatedMessage string = "Peer login expiration duration updated"
// PersonalAccessTokenCreatedMessage is a human-readable text message of the PersonalAccessTokenCreated activity
PersonalAccessTokenCreatedMessage string = "Personal access token created"
// PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity
PersonalAccessTokenDeletedMessage string = "Personal access token deleted"
)
// Activity that triggered an Event
@@ -258,6 +266,10 @@ func (a Activity) Message() string {
return AccountPeerLoginExpirationDisabledMessage
case AccountPeerLoginExpirationDurationUpdated:
return AccountPeerLoginExpirationDurationUpdatedMessage
case PersonalAccessTokenCreated:
return PersonalAccessTokenCreatedMessage
case PersonalAccessTokenDeleted:
return PersonalAccessTokenDeletedMessage
default:
return "UNKNOWN_ACTIVITY"
}
@@ -348,6 +360,10 @@ func (a Activity) StringCode() string {
return "account.setting.peer.login.expiration.enable"
case AccountPeerLoginExpirationDisabled:
return "account.setting.peer.login.expiration.disable"
case PersonalAccessTokenCreated:
return "personal.access.token.create"
case PersonalAccessTokenDeleted:
return "personal.access.token.delete"
default:
return "UNKNOWN_ACTIVITY"
}

View File

@@ -2,10 +2,12 @@ package sqlite
import (
"fmt"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/stretchr/testify/assert"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity"
)
func TestNewSQLiteStore(t *testing.T) {
@@ -15,13 +17,13 @@ func TestNewSQLiteStore(t *testing.T) {
t.Fatal(err)
return
}
defer store.Close() //nolint
defer store.Close() //nolint
accountID := "account_1"
for i := 0; i < 10; i++ {
_, err = store.Save(&activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user_" + fmt.Sprint(i),
TargetID: "peer_" + fmt.Sprint(i),

View File

@@ -24,6 +24,11 @@ const (
NONE Provider = "none"
)
const (
// DefaultDeviceAuthFlowScope defines the bare minimum scope to request in the device authorization flow
DefaultDeviceAuthFlowScope string = "openid"
)
// Config of the Management service
type Config struct {
Stuns []*Host
@@ -39,6 +44,17 @@ type Config struct {
DeviceAuthorizationFlow *DeviceAuthorizationFlow
}
// GetAuthAudiences returns the audience from the http config and device authorization flow config
func (c Config) GetAuthAudiences() []string {
audiences := []string{c.HttpConfig.AuthAudience}
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
}
return audiences
}
// TURNConfig is a config of the TURNCredentialsManager
type TURNConfig struct {
TimeBasedCredentials bool
@@ -98,6 +114,10 @@ type ProviderConfig struct {
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
}
// validateURL validates input http url

View File

@@ -2,9 +2,11 @@ package server
import (
"fmt"
"github.com/netbirdio/netbird/management/server/activity"
log "github.com/sirupsen/logrus"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
)
// GetEvents returns a list of activity events of an account
@@ -39,7 +41,7 @@ func (am *DefaultAccountManager) storeEvent(initiatorID, targetID, accountID str
go func() {
_, err := am.eventStore.Save(&activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activityID,
InitiatorID: initiatorID,
TargetID: targetID,
@@ -47,7 +49,7 @@ func (am *DefaultAccountManager) storeEvent(initiatorID, targetID, accountID str
Meta: meta,
})
if err != nil {
//todo add metric
// todo add metric
log.Errorf("received an error while storing an activity event, error: %s", err)
}
}()

View File

@@ -1,17 +1,19 @@
package server
import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/stretchr/testify/assert"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity"
)
func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ activity.Activity, initiatorID, targetID,
accountID string, count int) {
for i := 0; i < count; i++ {
_, err := manager.eventStore.Save(&activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: typ,
InitiatorID: initiatorID,
TargetID: targetID,

View File

@@ -112,7 +112,7 @@ func restore(file string) (*FileStore, error) {
store.UserID2AccountID[user.Id] = accountID
for _, pat := range user.PATs {
store.TokenID2UserID[pat.ID] = user.Id
store.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID
store.HashedPAT2TokenID[pat.HashedToken] = pat.ID
}
}
@@ -121,15 +121,25 @@ func restore(file string) (*FileStore, error) {
store.PrivateDomain2AccountID[account.Domain] = accountID
}
// if no policies are defined, that means we need to migrate Rules to policies
if len(account.Policies) == 0 {
// TODO: policy query generated from the Go template and rule object.
// We need to refactor this part to avoid using templating for policies queries building
// and drop this migration part.
policies := make(map[string]int, len(account.Policies))
for i, policy := range account.Policies {
policies[policy.ID] = i
}
if account.Policies == nil {
account.Policies = make([]*Policy, 0)
for _, rule := range account.Rules {
policy, err := RuleToPolicy(rule)
if err != nil {
log.Errorf("unable to migrate rule to policy: %v", err)
continue
}
}
for _, rule := range account.Rules {
policy, err := RuleToPolicy(rule)
if err != nil {
log.Errorf("unable to migrate rule to policy: %v", err)
continue
}
if i, ok := policies[policy.ID]; ok {
account.Policies[i] = policy
} else {
account.Policies = append(account.Policies, policy)
}
}
@@ -163,7 +173,7 @@ func restore(file string) (*FileStore, error) {
for key, peer := range account.Peers {
// set LastLogin for the peers that were onboarded before the peer login expiration feature
if peer.LastLogin.IsZero() {
peer.LastLogin = time.Now()
peer.LastLogin = time.Now().UTC()
}
if peer.ID != "" {
continue
@@ -268,7 +278,7 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.UserID2AccountID[user.Id] = accountCopy.Id
for _, pat := range user.PATs {
s.TokenID2UserID[pat.ID] = user.Id
s.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID
s.HashedPAT2TokenID[pat.HashedToken] = pat.ID
}
}

View File

@@ -95,7 +95,7 @@ func TestSaveAccount(t *testing.T) {
IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{},
Name: "peer name",
Status: &PeerStatus{Connected: true, LastSeen: time.Now()},
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
// SaveAccount should trigger persist
@@ -131,7 +131,7 @@ func TestStore(t *testing.T) {
IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{},
Name: "peer name",
Status: &PeerStatus{Connected: true, LastSeen: time.Now()},
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account.Groups["all"] = &Group{
ID: "all",
@@ -514,7 +514,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
}
// save status of non-existing peer
newStatus := PeerStatus{Connected: true, LastSeen: time.Now()}
newStatus := PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
@@ -526,7 +526,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{},
Name: "peer name",
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
Status: &PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(account)

View File

@@ -6,11 +6,10 @@ import (
"strings"
"time"
pb "github.com/golang/protobuf/proto" //nolint
pb "github.com/golang/protobuf/proto" // nolint
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/golang/protobuf/ptypes/timestamp"
@@ -33,7 +32,7 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager
config *Config
turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware
jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
}
@@ -47,12 +46,12 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, err
}
var jwtMiddleware *middleware.JWTMiddleware
var jwtValidator *jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtMiddleware, err = middleware.NewJwtMiddleware(
jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
@@ -88,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
accountManager: accountManager,
config: config,
turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware,
jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics,
}, nil
@@ -189,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) {
}
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
if s.jwtMiddleware == nil {
return "", status.Error(codes.Internal, "no jwt middleware set")
if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set")
}
token, err := s.jwtMiddleware.ValidateAndParse(jwtToken)
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}
@@ -511,6 +510,8 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
},
}

View File

@@ -6,6 +6,8 @@ info:
tags:
- name: Users
description: Interact with and view information about users.
- name: Tokens
description: Interact with and view information about tokens.
- name: Peers
description: Interact with and view information about peers.
- name: Setup Keys
@@ -284,6 +286,59 @@ components:
- revoked
- auto_groups
- usage_limit
PersonalAccessToken:
type: object
properties:
id:
description: ID of a token
type: string
name:
description: Name of the token
type: string
expiration_date:
description: Date the token expires
type: string
format: date-time
created_by:
description: User ID of the user who created the token
type: string
created_at:
description: Date the token was created
type: string
format: date-time
last_used:
description: Date the token was last used
type: string
format: date-time
required:
- id
- name
- expiration_date
- created_by
- created_at
PersonalAccessTokenGenerated:
type: object
properties:
plain_token:
description: Plain text representation of the generated token
type: string
personal_access_token:
$ref: '#/components/schemas/PersonalAccessToken'
required:
- plain_token
- personal_access_token
PersonalAccessTokenRequest:
type: object
properties:
name:
description: Name of the token
type: string
expires_in:
description: Expiration in days
type: integer
required:
- name
- expires_in
GroupMinimum:
type: object
properties:
@@ -848,6 +903,133 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/{userId}/tokens:
get:
summary: Returns a list of all tokens for a user
tags: [ Tokens ]
security:
- BearerAuth: []
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The User ID
responses:
'200':
description: A JSON Array of PersonalAccessTokens
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/PersonalAccessToken'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a new token
tags: [ Tokens ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The User ID
requestBody:
description: PersonalAccessToken create parameters
content:
application/json:
schema:
$ref: '#/components/schemas/PersonalAccessTokenRequest'
responses:
'200':
description: The token in plain text
content:
application/json:
schema:
$ref: '#/components/schemas/PersonalAccessTokenGenerated'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/{userId}/tokens/{tokenId}:
get:
summary: Returns a specific token
tags: [ Tokens ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The User ID
- in: path
name: tokenId
required: true
schema:
type: string
description: The Token ID
responses:
'200':
description: A PersonalAccessTokens Object
content:
application/json:
schema:
$ref: '#/components/schemas/PersonalAccessToken'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a token
tags: [ Tokens ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: userId
required: true
schema:
type: string
description: The User ID
- in: path
name: tokenId
required: true
schema:
type: string
description: The Token ID
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers:
get:
summary: Returns a list of all peers

View File

@@ -379,6 +379,44 @@ type PeerMinimum struct {
Name string `json:"name"`
}
// PersonalAccessToken defines model for PersonalAccessToken.
type PersonalAccessToken struct {
// CreatedAt Date the token was created
CreatedAt time.Time `json:"created_at"`
// CreatedBy User ID of the user who created the token
CreatedBy string `json:"created_by"`
// ExpirationDate Date the token expires
ExpirationDate time.Time `json:"expiration_date"`
// Id ID of a token
Id string `json:"id"`
// LastUsed Date the token was last used
LastUsed *time.Time `json:"last_used,omitempty"`
// Name Name of the token
Name string `json:"name"`
}
// PersonalAccessTokenGenerated defines model for PersonalAccessTokenGenerated.
type PersonalAccessTokenGenerated struct {
PersonalAccessToken PersonalAccessToken `json:"personal_access_token"`
// PlainToken Plain text representation of the generated token
PlainToken string `json:"plain_token"`
}
// PersonalAccessTokenRequest defines model for PersonalAccessTokenRequest.
type PersonalAccessTokenRequest struct {
// ExpiresIn Expiration in days
ExpiresIn int `json:"expires_in"`
// Name Name of the token
Name string `json:"name"`
}
// Policy defines model for Policy.
type Policy struct {
// Description Policy friendly description
@@ -808,3 +846,6 @@ type PostApiUsersJSONRequestBody = UserCreateRequest
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
type PutApiUsersIdJSONRequestBody = UserRequest
// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType.
type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest

View File

@@ -54,7 +54,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
ID := uint64(1)
events := make([]*activity.Event, 0)
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
ID: ID,
InitiatorID: userID,
@@ -64,7 +64,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.UserJoined,
ID: ID,
InitiatorID: userID,
@@ -74,7 +74,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.GroupCreated,
ID: ID,
InitiatorID: userID,
@@ -84,7 +84,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.SetupKeyUpdated,
ID: ID,
InitiatorID: userID,
@@ -94,7 +94,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.SetupKeyUpdated,
ID: ID,
InitiatorID: userID,
@@ -104,7 +104,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.SetupKeyRevoked,
ID: ID,
InitiatorID: userID,
@@ -114,7 +114,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.SetupKeyOverused,
ID: ID,
InitiatorID: userID,
@@ -124,7 +124,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.SetupKeyCreated,
ID: ID,
InitiatorID: userID,
@@ -134,7 +134,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.RuleAdded,
ID: ID,
InitiatorID: userID,
@@ -144,7 +144,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.RuleRemoved,
ID: ID,
InitiatorID: userID,
@@ -154,7 +154,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.RuleUpdated,
ID: ID,
InitiatorID: userID,
@@ -164,7 +164,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
})
ID++
events = append(events, &activity.Event{
Timestamp: time.Now(),
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedWithSetupKey,
ID: ID,
InitiatorID: userID,

View File

@@ -300,7 +300,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
return
}
util.WriteJSONObject(w, "")
util.WriteJSONObject(w, emptyObject{})
}
// GetGroup returns a group

View File

@@ -8,6 +8,7 @@ import (
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -25,15 +26,17 @@ type apiHandler struct {
AuthCfg AuthCfg
}
// EmptyObject is an empty struct used to return empty JSON object
type emptyObject struct {
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware(
authCfg.Issuer,
authCfg.Audience,
authCfg.KeysLocation)
if err != nil {
return nil, err
}
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
authCfg.Audience)
corsMiddleware := cors.AllowAll()
@@ -46,7 +49,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
metricsMiddleware := appMetrics.HTTPMiddleware()
router := rootRouter.PathPrefix("/api").Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := apiHandler{
Router: router,
@@ -57,6 +60,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
api.addAccountsEndpoint()
api.addPeersEndpoint()
api.addUsersEndpoint()
api.addUsersTokensEndpoint()
api.addSetupKeysEndpoint()
api.addRulesEndpoint()
api.addPoliciesEndpoint()
@@ -66,7 +70,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
api.addDNSSettingEndpoint()
api.addEventsEndpoint()
err = api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return err
@@ -110,6 +114,14 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addSetupKeysEndpoint() {
keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS")

View File

@@ -2,6 +2,9 @@ package middleware
import (
"net/http"
"regexp"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
@@ -39,10 +42,22 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if !ok {
switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
if err != nil {
log.Debugf("Regex failed")
util.WriteError(status.Errorf(status.Internal, ""), w)
return
}
if ok {
log.Debugf("Valid Path")
h.ServeHTTP(w, r)
return
}
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
return
}

View File

@@ -0,0 +1,173 @@
package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
audience string
}
const (
userProperty = "user"
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware {
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
audience: audience,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := auth[0]
switch strings.ToLower(authType) {
case "bearer":
err := m.CheckJWTFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := m.CheckPATFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromJWTRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(token)
if err != nil {
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromPATRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.ExpirationDate) {
return fmt.Errorf("token expired")
}
err = m.markPATUsed(pat.ID)
if err != nil {
return err
}
claimMaps := jwt.MapClaims{}
claimMaps[jwtclaims.UserIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts
// the PAT token from the Authorization header.
func getTokenFromPATRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
return "", errors.New("Authorization header format must be Token {token}")
}
return authHeaderParts[1], nil
}

View File

@@ -0,0 +1,123 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
)
const (
audience = "audience"
accountID = "accountID"
domain = "domain"
userID = "userID"
tokenID = "tokenID"
PAT = "PAT"
JWT = "JWT"
wrongToken = "wrongToken"
)
var testAccount = &server.Account{
Id: accountID,
Domain: domain,
Users: map[string]*server.User{
userID: {
Id: userID,
PATs: map[string]*server.PersonalAccessToken{
tokenID: {
ID: tokenID,
Name: "My first token",
HashedToken: "someHash",
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
CreatedBy: userID,
CreatedAt: time.Now().UTC(),
LastUsed: time.Now().UTC(),
},
},
},
},
}
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{}, nil
}
return nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(token string) error {
if token == tokenID {
return nil
}
return fmt.Errorf("Should never get reached")
}
func TestAuthMiddleware_Handler(t *testing.T) {
tt := []struct {
name string
authHeader string
expectedStatusCode int
}{
{
name: "Valid PAT Token",
authHeader: "Token " + PAT,
expectedStatusCode: 200,
},
{
name: "Invalid PAT Token",
authHeader: "Token " + wrongToken,
expectedStatusCode: 401,
},
{
name: "Valid JWT Token",
authHeader: "Bearer " + JWT,
expectedStatusCode: 200,
},
{
name: "Invalid JWT Token",
authHeader: "Bearer " + wrongToken,
expectedStatusCode: 401,
},
{
name: "Basic Auth",
authHeader: "Basic " + PAT,
expectedStatusCode: 401,
},
}
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// do nothing
})
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience)
handlerToTest := authMiddleware.Handler(nextHandler)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://testing", nil)
req.Header.Set("Authorization", tc.authHeader)
rec := httptest.NewRecorder()
handlerToTest.ServeHTTP(rec, req)
if rec.Result().StatusCode != tc.expectedStatusCode {
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode)
}
})
}
}

View File

@@ -1,254 +0,0 @@
package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
)
// A function called whenever an error is encountered
type errorHandler func(w http.ResponseWriter, r *http.Request, err string)
// TokenExtractor is a function that takes a request as input and returns
// either a token or an error. An error should only be returned if an attempt
// to specify a token was found, but the information was somehow incorrectly
// formed. In the case where a token is simply not present, this should not
// be treated as an error. An empty string should be returned in that case.
type TokenExtractor func(r *http.Request) (string, error)
// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
// It can be either a shared secret or a public key.
// Default value: nil
ValidationKeyGetter jwt.Keyfunc
// The name of the property in the request where the user information
// from the JWT will be stored.
// Default value: "user"
UserProperty string
// The function that will be called when there's an error validating the token
// Default value:
ErrorHandler errorHandler
// A boolean indicating if the credentials are required or not
// Default value: false
CredentialsOptional bool
// A function that extracts the token from the request
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
Extractor TokenExtractor
// Debug flag turns on debugging output
// Default: false
Debug bool
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
}
type JWTMiddleware struct {
Options Options
}
func OnError(w http.ResponseWriter, r *http.Request, err string) {
util.WriteError(status.Errorf(status.Unauthorized, ""), w)
}
// New constructs a new Secure instance with supplied options.
func New(options ...Options) *JWTMiddleware {
var opts Options
if len(options) == 0 {
opts = Options{}
} else {
opts = options[0]
}
if opts.UserProperty == "" {
opts.UserProperty = "user"
}
if opts.ErrorHandler == nil {
opts.ErrorHandler = OnError
}
if opts.Extractor == nil {
opts.Extractor = FromAuthHeader
}
return &JWTMiddleware{
Options: opts,
}
}
func (m *JWTMiddleware) logf(format string, args ...interface{}) {
if m.Options.Debug {
log.Printf(format, args...)
}
}
// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere.
func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
err := m.CheckJWTFromRequest(w, r)
// If there was an error, do not call next.
if err == nil && next != nil {
next(w, r)
}
}
func (m *JWTMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Let secure process the request. If it returns an error,
// that indicates the request should not continue.
err := m.CheckJWTFromRequest(w, r)
// If there was an error, do not continue.
if err != nil {
log.Errorf("received an error while validating the JWT token: %s. "+
"Review your IDP configuration and ensure that "+
"settings are in sync between dashboard and management", err)
return
}
h.ServeHTTP(w, r)
})
}
// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func FromAuthHeader(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// FromParameter returns a function that extracts the token from the specified
// query string parameter
func FromParameter(param string) TokenExtractor {
return func(r *http.Request) (string, error) {
return r.URL.Query().Get(param), nil
}
}
// FromFirst returns a function that runs multiple token extractors and takes the
// first token it finds
func FromFirst(extractors ...TokenExtractor) TokenExtractor {
return func(r *http.Request) (string, error) {
for _, ex := range extractors {
token, err := ex(r)
if err != nil {
return "", err
}
if token != "" {
return token, nil
}
}
return "", nil
}
}
func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
if !m.Options.EnableAuthOnOptions {
if r.Method == "OPTIONS" {
return nil
}
}
// Use the specified token extractor to extract a token from the request
token, err := m.Options.Extractor(r)
// If debugging is turned on, log the outcome
if err != nil {
m.logf("Error extracting JWT: %v", err)
} else {
m.logf("Token extracted: %s", token)
}
// If an error occurs, call the error handler and return an error
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.ValidateAndParse(token)
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// ValidateAndParse validates and parses a given access token against jwt standards and signing methods
func (m *JWTMiddleware) ValidateAndParse(token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.Options.CredentialsOptional {
m.logf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
m.logf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter)
// Check if there was an error in parsing...
if err != nil {
m.logf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.Options.SigningMethod.Alg(),
parsedToken.Header["alg"])
m.logf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
m.logf(errorMsg)
return nil, errors.New(errorMsg)
}
return parsedToken, nil
}

View File

@@ -243,7 +243,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return
}
util.WriteJSONObject(w, "")
util.WriteJSONObject(w, emptyObject{})
}
// GetNameserverGroup handles a nameserver group Get request identified by ID

View File

@@ -0,0 +1,178 @@
package http
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// PATHandler is the nameserver group handler of the account
type PATHandler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewPATsHandler creates a new PATHandler HTTP handler
func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler {
return &PATHandler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
userID := vars["userId"]
if len(userID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID)
if err != nil {
util.WriteError(err, w)
return
}
var patResponse []*api.PersonalAccessToken
for _, pat := range pats {
patResponse = append(patResponse, toPATResponse(pat))
}
util.WriteJSONObject(w, patResponse)
}
// GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
tokenID := vars["tokenId"]
if len(tokenID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return
}
pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPATResponse(pat))
}
// CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
var req api.PostApiUsersUserIdTokensJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPATGeneratedResponse(pat))
}
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
tokenID := vars["tokenId"]
if len(tokenID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return
}
err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
}
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
var lastUsed *time.Time
if !pat.LastUsed.IsZero() {
lastUsed = &pat.LastUsed
}
return &api.PersonalAccessToken{
CreatedAt: pat.CreatedAt,
CreatedBy: pat.CreatedBy,
Name: pat.Name,
ExpirationDate: pat.ExpirationDate,
Id: pat.ID,
LastUsed: lastUsed,
}
}
func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated {
return &api.PersonalAccessTokenGenerated{
PlainToken: pat.PlainToken,
PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken),
}
}

View File

@@ -0,0 +1,254 @@
package http
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
)
const (
existingAccountID = "existingAccountID"
notFoundAccountID = "notFoundAccountID"
existingUserID = "existingUserID"
notFoundUserID = "notFoundUserID"
existingTokenID = "existingTokenID"
notFoundTokenID = "notFoundTokenID"
domain = "hotmail.com"
)
var testAccount = &server.Account{
Id: existingAccountID,
Domain: domain,
Users: map[string]*server.User{
existingUserID: {
Id: existingUserID,
PATs: map[string]*server.PersonalAccessToken{
existingTokenID: {
ID: existingTokenID,
Name: "My first token",
HashedToken: "someHash",
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
CreatedBy: existingUserID,
CreatedAt: time.Now().UTC(),
LastUsed: time.Now().UTC(),
},
"token2": {
ID: "token2",
Name: "My second token",
HashedToken: "someOtherHash",
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
CreatedBy: existingUserID,
CreatedAt: time.Now().UTC(),
LastUsed: time.Now().UTC(),
},
},
},
},
}
func initPATTestData() *PATHandler {
return &PATHandler{
accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
return &server.PersonalAccessTokenGenerated{
PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe",
PersonalAccessToken: server.PersonalAccessToken{},
}, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil
},
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
if tokenID != existingTokenID {
return status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
}
return nil
},
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
if tokenID != existingTokenID {
return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
}
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
},
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: domain,
AccountId: testNSGroupAccountID,
}
}),
),
}
}
func TestTokenHandlers(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "Get All Tokens",
requestType: http.MethodGet,
requestPath: "/api/users/" + existingUserID + "/tokens",
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Get Existing Token",
requestType: http.MethodGet,
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Get Not Existing Token",
requestType: http.MethodGet,
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
expectedStatus: http.StatusNotFound,
},
{
name: "Delete Existing Token",
requestType: http.MethodDelete,
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
expectedStatus: http.StatusOK,
},
{
name: "Delete Not Existing Token",
requestType: http.MethodDelete,
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
expectedStatus: http.StatusNotFound,
},
{
name: "POST OK",
requestType: http.MethodPost,
requestPath: "/api/users/" + existingUserID + "/tokens",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"expires_in\":7}")),
expectedStatus: http.StatusOK,
expectedBody: true,
},
}
p := initPATTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
switch tc.name {
case "POST OK":
got := &api.PersonalAccessTokenGenerated{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotEmpty(t, got.PlainToken)
assert.Equal(t, server.PATLength, len(got.PlainToken))
case "Get All Tokens":
expectedTokens := []api.PersonalAccessToken{
toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]),
toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]),
}
var got []api.PersonalAccessToken
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.True(t, cmp.Equal(got, expectedTokens))
case "Get Existing Token":
expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID])
got := &api.PersonalAccessToken{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.True(t, cmp.Equal(*got, expectedToken))
}
})
}
}
func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken {
return api.PersonalAccessToken{
Id: serverToken.ID,
Name: serverToken.Name,
CreatedAt: serverToken.CreatedAt,
LastUsed: &serverToken.LastUsed,
CreatedBy: serverToken.CreatedBy,
ExpirationDate: serverToken.ExpirationDate,
}
}

View File

@@ -66,7 +66,7 @@ func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w htt
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, "")
util.WriteJSONObject(w, emptyObject{})
}
// HandlePeer handles all peer requests for GET, PUT and DELETE operations

Some files were not shown because too many files have changed in this diff Show More