Compare commits

...

75 Commits

Author SHA1 Message Date
Zoltan Papp
0ab0b834eb Fix 2024-12-19 14:01:52 +01:00
Zoltán Papp
9ff03141ba Add peer conn profile code 2024-12-19 13:53:51 +01:00
Maycon Santos
37ad370344 [client] Avoid using iota on mixed const block (#3057)
Used the values as resolved when the first iota value was the second const in the block.
2024-12-16 18:09:31 +01:00
VYSE V.E.O
703647da1e fix client unsupported h2 protocol when only 443 activated (#3009)
When I remove 80 http port in Caddyfile, netbird client cannot connect server:443. Logs show error below:
{"level":"debug","ts":1733809631.4012625,"logger":"http.stdlib","msg":"http: TLS handshake error from redacted:41580: tls: client requested unsupported application protocols ([h2])"}
I wonder here h2 protocol is absent.
2024-12-16 14:17:46 +01:00
Maycon Santos
9eff58ae62 Upgrade x/crypto package (#3055)
Mitigates the CVE-2024-45337
2024-12-16 10:30:41 +01:00
Jesse R Codling
3844516aa7 [client] fix: reformat IPv6 ICE addresses when punching (#3050)
Should fix #2327 and #2606 by checking for IPv6 addresses from ICE
2024-12-16 09:58:54 +01:00
M. Essam
f591e47404 Handle DNF5 install script (#3026) 2024-12-16 09:41:36 +01:00
Maycon Santos
287ae81195 [misc] split tests with management and rest (#3051)
optimize go cache for tests
2024-12-14 21:18:46 +01:00
M. Essam
a4a30744ad Fix race condition with systray ready (#2993) 2024-12-14 12:17:53 -08:00
Maycon Santos
dcba6a6b7e fix: client/Dockerfile to reduce vulnerabilities (#3019)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE320-OPENSSL-8235201
- https://snyk.io/vuln/SNYK-ALPINE320-OPENSSL-8235201

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2024-12-11 16:46:51 +01:00
Pascal Fischer
6142828a9c [management] restructure api files (#3013) 2024-12-10 15:59:25 +01:00
Bethuel Mmbaga
97bb74f824 Remove peer login log (#3005)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-09 18:40:06 +01:00
Maycon Santos
2147bf75eb [client] Add peer conn init limit (#3001)
Limit the peer connection initialization to 200 peers at the same time
2024-12-09 17:10:31 +01:00
Pascal Fischer
e40a29ba17 [client] Add support for state manager on iOS (#2996) 2024-12-06 16:51:42 +01:00
Edouard Vanbelle
ff330e644e upgrade zcalusic/sysinfo@v1.1.3 (add serial for ARM arch) (#2954)
Signed-off-by: Edouard Vanbelle <edouard.vanbelle@shadow.tech>
2024-12-05 15:38:00 +01:00
M. Essam
713e320c4c Update account peers on login on meta change (#2991)
* Update account peers on login on meta change

* Factor out LoginPeer peer not found handling
2024-12-05 14:15:23 +01:00
Maycon Santos
e67fe89adb Reduce max wait time to initialize peer connections (#2984)
* Reduce max wait time to initialize peer connections

setting rand time range to 100-300ms instead of 100-800ms

* remove min wait time
2024-12-05 13:03:11 +01:00
Viktor Liu
6cfbb1f320 [client] Init route selector early (#2989) 2024-12-05 12:41:12 +01:00
Viktor Liu
c853011a32 [client] Don't return error in rule removal if protocol is not supported (#2990) 2024-12-05 12:28:35 +01:00
Maycon Santos
b50b89ba14 [client] Cleanup status resources on engine stop (#2981)
cleanup leftovers from status recorder when stopping the engine
2024-12-04 14:09:04 +01:00
Pascal Fischer
d063fbb8b9 [management] merge update account peers in sync call (#2978) 2024-12-03 16:41:19 +01:00
Viktor Liu
e5d42bc963 [client] Add state handling cmdline options (#2821) 2024-12-03 16:07:18 +01:00
Viktor Liu
8866394eb6 [client] Don't choke on non-existent interface in route updates (#2922) 2024-12-03 15:33:41 +01:00
Viktor Liu
17c20b45ce [client] Add network map to debug bundle (#2966) 2024-12-03 14:50:12 +01:00
Joakim Nohlgård
7dacd9cb23 [management] Add missing parentheses on iphone hostname generation condition (#2977) 2024-12-03 13:49:02 +01:00
Viktor Liu
6285e0d23e [client] Add netbird.err and netbird.out to debug bundle (#2971) 2024-12-03 12:43:17 +01:00
Maycon Santos
a4826cfb5f [client] Get static system info once (#2965)
Get static system info once for Windows, Darwin, and Linux nodes

This should improve startup and peer authentication times
2024-12-03 10:22:04 +01:00
Zoltan Papp
a0bf0bdcc0 Pass IP instead of net to Rosenpass (#2975) 2024-12-03 10:13:27 +01:00
Viktor Liu
dffce78a8c [client] Fix debug bundle state anonymization test (#2976) 2024-12-02 20:19:34 +01:00
Viktor Liu
c7e7ad5030 [client] Add state file to debug bundle (#2969) 2024-12-02 18:04:02 +01:00
Viktor Liu
5142dc52c1 [client] Persist route selection (#2810) 2024-12-02 17:55:02 +01:00
Zoltan Papp
ecb44ff306 [client] Add pprof build tag (#2964)
* Add pprof build tag

* Change env handling
2024-12-01 19:22:52 +01:00
victorserbu2709
e4a5fb3e91 Unspecified address: default NetworkTypeUDP4+NetworkTypeUDP6 (#2804) 2024-11-30 10:34:52 +01:00
v1rusnl
e52d352a48 Update Caddyfile and Docker Compose to support HTTP3 (#2822) 2024-11-30 10:26:31 +01:00
Maycon Santos
f9723c9266 [client] Account different policiy rules for routes firewall rules (#2939)
* Account different policies rules for routes firewall rules

This change ensures that route firewall rules will consider source group peers in the rules generation for access control policies.

This fixes the behavior where multiple policies with different levels of access was being applied to all peers in a distribution group

* split function

* avoid unnecessary allocation

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2024-11-29 17:50:35 +01:00
Maycon Santos
8efad1d170 Add guide when signing key is not found (#2942)
Some users face issues with their IdP due to signing key not being refreshed

With this change we advise users to configure key refresh

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* removing leftover

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2024-11-29 10:06:40 +01:00
Pascal Fischer
c6641be94b [tests] Enable benchmark tests on github actions (#2961) 2024-11-28 19:22:01 +01:00
Pascal Fischer
89cf8a55e2 [management] Add performance test for login and sync calls (#2960) 2024-11-28 14:59:53 +01:00
Pascal Fischer
00c3b67182 [management] refactor to use account object instead of separate db calls for peer update (#2957) 2024-11-28 11:13:01 +01:00
Zoltan Papp
9203690033 [client] Code cleaning in net pkg and fix exit node feature on Android(#2932)
Code cleaning around the util/net package. The goal was to write a more understandable source code but modify nothing on the logic.
Protect the WireGuard UDP listeners with marks.
The implementation can support the VPN permission revocation events in thread safe way. It will be important if we start to support the running time route and DNS update features.

- uniformize the file name convention: [struct_name] _ [functions] _ [os].go
- code cleaning in net_linux.go
- move env variables to env.go file
2024-11-26 23:34:27 +01:00
Bethuel Mmbaga
9683da54b0 [management] Refactor nameserver groups to use store methods (#2888) 2024-11-26 17:39:04 +01:00
Bethuel Mmbaga
0e48a772ff [management] Refactor DNS settings to use store methods (#2883) 2024-11-26 13:43:05 +01:00
Bethuel Mmbaga
f118d81d32 [management] Refactor policy to use store methods (#2878) 2024-11-26 10:46:05 +01:00
Bethuel Mmbaga
ca12bc6953 [management] Refactor posture check to use store methods (#2874) 2024-11-25 16:26:24 +01:00
Viktor Liu
9810386937 [client] Allow routing to fallback to exclusion routes if rules are not supported (#2909) 2024-11-25 15:19:56 +01:00
Viktor Liu
f1625b32bd [client] Set up sysctl and routing table name only if routing rules are available (#2933) 2024-11-25 15:12:16 +01:00
Viktor Liu
0ecd5f2118 [client] Test nftables for incompatible iptables rules (#2948) 2024-11-25 15:11:56 +01:00
Viktor Liu
940d0c48c6 [client] Don't return error in userspace mode without firewall (#2924) 2024-11-25 15:11:31 +01:00
Maycon Santos
56cecf849e Import time package (#2940) 2024-11-22 20:40:30 +01:00
Maycon Santos
05c4aa7c2c [misc] Renew slack link (#2938) 2024-11-22 18:50:47 +01:00
Zoltan Papp
2a5cb16494 [relay] Refactor initial Relay connection (#2800)
Can support firewalls with restricted WS rules

allow to run engine without Relay servers
keep up to date Relay address changes
2024-11-22 18:12:34 +01:00
Pascal Fischer
9db1932664 [management] Fix getSetupKey call (#2927) 2024-11-22 10:15:51 +01:00
Viktor Liu
1bbabf70b0 [client] Fix allow netbird rule verdict (#2925)
* Fix allow netbird rule verdict

* Fix chain name
2024-11-21 16:53:37 +01:00
Pascal Fischer
aa575d6f44 [management] Add activity events to group propagation flow (#2916) 2024-11-21 15:10:34 +01:00
Pascal Fischer
f66bbcc54c [management] Add metric for peer meta update (#2913) 2024-11-19 18:13:26 +01:00
Pascal Fischer
5dd6a08ea6 link peer meta update back to account object (#2911) 2024-11-19 17:25:49 +01:00
Krzysztof Nazarewski (kdn)
eb5d0569ae [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899)
* dialer: fix crash instead of returning error

* add NB_SKIP_SOCKET_MARK
2024-11-19 14:14:58 +01:00
Pascal Fischer
52ea2e84e9 [management] Add transaction metrics and exclude getAccount time from peers update (#2904) 2024-11-19 00:04:50 +01:00
Maycon Santos
78fab877c0 [misc] Update signing pipeline version (#2900) 2024-11-18 15:31:53 +01:00
Maycon Santos
65a94f695f use google domain for tests (#2902) 2024-11-18 12:55:02 +01:00
Kursat Aktas
ec543f89fb Introducing NetBird Guru on Gurubase.io (#2778) 2024-11-16 15:45:31 +01:00
Viktor Liu
a7d5c52203 Fix error state race on mgmt connection error (#2892) 2024-11-15 22:59:49 +01:00
Viktor Liu
582bb58714 Move state updates outside the refcounter (#2897) 2024-11-15 22:55:33 +01:00
Viktor Liu
121dfda915 [client] Fix state manager race conditions (#2890) 2024-11-15 20:05:26 +01:00
İsmail
a1c5287b7c Fix the Inactivity Expiration problem. (#2865) 2024-11-15 18:21:27 +01:00
Bethuel Mmbaga
12f442439a [management] Refactor group to use store methods (#2867)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-15 20:09:32 +03:00
Pascal Fischer
d9b691b8a5 [management] Limit the setup-key update operation (#2841) 2024-11-15 17:00:06 +01:00
Pascal Fischer
4aee3c9e33 [client/management] add peer lock to peer meta update and fix isEqual func (#2840) 2024-11-15 16:59:03 +01:00
Pascal Fischer
44e799c687 [management] Fix limited peer view groups (#2894) 2024-11-15 11:16:16 +01:00
Viktor Liu
be78efbd42 [client] Handle panic on nil wg interface (#2891) 2024-11-14 20:15:16 +01:00
Maycon Santos
6886691213 Update route calculation tests (#2884)
- Add two new test cases for p2p and relay routes with same latency
- Add extra statuses generation
2024-11-13 15:21:33 +01:00
Zoltan Papp
b48afd92fd [relay-server] Always close ws conn when work thread exit (#2879)
Close ws conn when work thread exit
2024-11-13 15:02:51 +01:00
Viktor Liu
39329e12a1 [client] Improve state write timeout and abort work early on timeout (#2882)
* Improve state write timeout and abort work early on timeout

* Don't block on initial persist state
2024-11-13 13:46:00 +01:00
Pascal Fischer
20a5afc359 [management] Add more logs to the peer update processes (#2881) 2024-11-12 14:19:22 +01:00
Bethuel Mmbaga
6cb697eed6 [management] Refactor setup key to use store methods (#2861)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove context from DB queries

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add user permission check and add setup events into events to store slice

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve all groups once during setup key auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 19:46:10 +03:00
176 changed files with 8944 additions and 2641 deletions

View File

@@ -21,6 +21,7 @@ jobs:
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: "1.23.x" go-version: "1.23.x"
cache: false
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -28,8 +29,9 @@ jobs:
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }} key: macos-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
macos-gotest-
macos-go- macos-go-
- name: Install libpcap - name: Install libpcap
@@ -42,4 +44,4 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -11,31 +11,115 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
build-cache:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
id: cache
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Build client
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 go build .
- name: Build client 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
- name: Build management
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 go build .
- name: Build management 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
- name: Build signal
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 go build .
- name: Build signal 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
- name: Build relay
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 go build .
- name: Build relay 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test: test:
needs: [build-cache]
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: "1.23.x" go-version: "1.23.x"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -50,27 +134,134 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_management:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m ./...
test_client_on_docker: test_client_on_docker:
needs: [ build-cache ]
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: "1.23.x" go-version: "1.23.x"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev

View File

@@ -24,6 +24,23 @@ jobs:
id: go id: go
with: with:
go-version: "1.23.x" go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-
${{ runner.os }}-go-
- name: Download wintun - name: Download wintun
uses: carlosperate/download-file-action@v2 uses: carlosperate/download-file-action@v2
@@ -42,11 +59,13 @@ jobs:
- run: choco install -y sysinternals --ignore-checksums - run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw - run: choco install -y mingw
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output - name: test output
if: ${{ always() }} if: ${{ always() }}
run: Get-Content test-out.txt run: Get-Content test-out.txt

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.16" SIGN_PIPE_VER: "v0.0.17"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

View File

@@ -17,8 +17,12 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a> </a>
</p> </p>
</div> </div>
@@ -30,7 +34,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
<br/> <br/>
</strong> </strong>

View File

@@ -1,4 +1,4 @@
FROM alpine:3.20 FROM alpine:3.21.0
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -12,6 +12,8 @@ import (
"strings" "strings"
) )
const anonTLD = ".domain"
type Anonymizer struct { type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string domainAnonymizer map[string]string
@@ -83,29 +85,39 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
} }
func (a *Anonymizer) AnonymizeDomain(domain string) string { func (a *Anonymizer) AnonymizeDomain(domain string) string {
if strings.HasSuffix(domain, "netbird.io") || baseDomain := domain
strings.HasSuffix(domain, "netbird.selfhosted") || hasDot := strings.HasSuffix(domain, ".")
strings.HasSuffix(domain, "netbird.cloud") || if hasDot {
strings.HasSuffix(domain, "netbird.stage") || baseDomain = domain[:len(domain)-1]
strings.HasSuffix(domain, ".domain") { }
if strings.HasSuffix(baseDomain, "netbird.io") ||
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
strings.HasSuffix(baseDomain, "netbird.cloud") ||
strings.HasSuffix(baseDomain, "netbird.stage") ||
strings.HasSuffix(baseDomain, anonTLD) {
return domain return domain
} }
parts := strings.Split(domain, ".") parts := strings.Split(baseDomain, ".")
if len(parts) < 2 { if len(parts) < 2 {
return domain return domain
} }
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1] baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseDomain] anonymized, ok := a.domainAnonymizer[baseForLookup]
if !ok { if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + ".domain" anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
a.domainAnonymizer[baseDomain] = anonymizedBase a.domainAnonymizer[baseForLookup] = anonymizedBase
anonymized = anonymizedBase anonymized = anonymizedBase
} }
return strings.Replace(domain, baseDomain, anonymized, 1) result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
if hasDot {
result += "."
}
return result
} }
func (a *Anonymizer) AnonymizeURI(uri string) string { func (a *Anonymizer) AnonymizeURI(uri string) string {
@@ -152,9 +164,9 @@ func (a *Anonymizer) AnonymizeString(str string) string {
return str return str
} }
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes. // AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string { func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`) re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI) return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
} }
@@ -168,10 +180,10 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
parts := strings.Split(match, `"`) parts := strings.Split(match, `"`)
if len(parts) >= 2 { if len(parts) >= 2 {
domain := parts[1] domain := parts[1]
if strings.HasSuffix(domain, ".domain") { if strings.HasSuffix(domain, anonTLD) {
return match return match
} }
randomDomain := generateRandomString(10) + ".domain" randomDomain := generateRandomString(10) + anonTLD
return strings.Replace(match, domain, randomDomain, 1) return strings.Replace(match, domain, randomDomain, 1)
} }
return match return match

View File

@@ -67,18 +67,36 @@ func TestAnonymizeDomain(t *testing.T) {
`^anon-[a-zA-Z0-9]+\.domain$`, `^anon-[a-zA-Z0-9]+\.domain$`,
true, true,
}, },
{
"Domain with Trailing Dot",
"example.com.",
`^anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{ {
"Subdomain", "Subdomain",
"sub.example.com", "sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`, `^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true, true,
}, },
{
"Subdomain with Trailing Dot",
"sub.example.com.",
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{ {
"Protected Domain", "Protected Domain",
"netbird.io", "netbird.io",
`^netbird\.io$`, `^netbird\.io$`,
false, false,
}, },
{
"Protected Domain with Trailing Dot",
"netbird.io.",
`^netbird\.io.$`,
false,
},
} }
for _, tc := range tests { for _, tc := range tests {
@@ -140,8 +158,16 @@ func TestAnonymizeSchemeURI(t *testing.T) {
expect string expect string
}{ }{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`}, {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`}, {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
} }
for _, tc := range tests { for _, tc := range tests {

View File

@@ -3,6 +3,7 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -61,6 +62,15 @@ var forCmd = &cobra.Command{
RunE: runForDuration, RunE: runForDuration,
} }
var persistenceCmd = &cobra.Command{
Use: "persistence [on|off]",
Short: "Set network map memory persistence",
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
Example: " netbird debug persistence on",
Args: cobra.ExactArgs(1),
RunE: setNetworkMapPersistence,
}
func debugBundle(cmd *cobra.Command, _ []string) error { func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd)
if err != nil { if err != nil {
@@ -171,6 +181,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
// Enable network map persistence before bringing the service up
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} }
@@ -200,6 +217,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown { if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -219,6 +243,34 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
persistence := strings.ToLower(args[0])
if persistence != "on" && persistence != "off" {
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
}
client := proto.NewDaemonServiceClient(conn)
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: persistence == "on",
})
if err != nil {
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
}
cmd.Printf("Network map persistence set to: %s\n", persistence)
return nil
}
func getStatusOutput(cmd *cobra.Command) string { func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string var statusOutputString string
statusResp, err := getStatus(cmd.Context()) statusResp, err := getStatus(cmd.Context())

33
client/cmd/pprof.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -155,6 +155,7 @@ func init() {
debugCmd.AddCommand(logCmd) debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd) logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd) debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+ `Sets external IPs maps between local addresses and interfaces.`+

181
client/cmd/state.go Normal file
View File

@@ -0,0 +1,181 @@
package cmd
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var (
allFlag bool
)
var stateCmd = &cobra.Command{
Use: "state",
Short: "Manage daemon state",
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
}
var stateListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all stored states",
Long: "Lists all registered states with their status and basic information.",
Example: " netbird state list",
RunE: stateList,
}
var stateCleanCmd = &cobra.Command{
Use: "clean [state-name]",
Short: "Clean stored states",
Long: `Clean specific state or all states. The daemon must not be running.
This will perform cleanup operations and remove the state.`,
Example: ` netbird state clean dns_state
netbird state clean --all`,
RunE: stateClean,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
var stateDeleteCmd = &cobra.Command{
Use: "delete [state-name]",
Short: "Delete stored states",
Long: `Delete specific state or all states from storage. The daemon must not be running.
This will remove the state without performing any cleanup operations.`,
Example: ` netbird state delete dns_state
netbird state delete --all`,
RunE: stateDelete,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
func init() {
rootCmd.AddCommand(stateCmd)
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
}
func stateList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
if err != nil {
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
}
cmd.Printf("\nStored states:\n\n")
for _, state := range resp.States {
cmd.Printf("- %s\n", state.Name)
}
return nil
}
func stateClean(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
}
if resp.CleanedStates == 0 {
cmd.Println("No states were cleaned")
return nil
}
if allFlag {
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
} else {
cmd.Printf("Successfully cleaned state %q\n", stateName)
}
return nil
}
func stateDelete(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
}
if resp.DeletedStates == 0 {
cmd.Println("No states were deleted")
return nil
}
if allFlag {
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
} else {
cmd.Printf("Successfully deleted state %q\n", stateName)
}
return nil
}

View File

@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
if err := stateManager.PersistState(context.Background()); err != nil { go func() {
log.Errorf("failed to persist state: %v", err) if err := stateManager.PersistState(context.Background()); err != nil {
} log.Errorf("failed to persist state: %v", err)
}
}()
return nil return nil
} }

View File

@@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
return err return err
} }
s.ips = temp.IPs s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil return nil
} }
@@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
return err return err
} }
s.ipsets = temp.IPSets s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil return nil
} }

View File

@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
// persist early // persist early
if err := stateManager.PersistState(context.Background()); err != nil { go func() {
log.Errorf("failed to persist state: %v", err) if err := stateManager.PersistState(context.Background()); err != nil {
} log.Errorf("failed to persist state: %v", err)
}
}()
return nil return nil
} }
@@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain var chain *nftables.Chain
for _, c := range chains { for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward { if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c chain = c
break break
} }
@@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains { for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" { if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c) rules, err := m.rConn.GetRules(c.Table, c)
if err != nil { if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err) log.Errorf("get rules for chain %q: %v", c.Name, err)
@@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Verdict{}, &expr.Verdict{
Kind: expr.VerdictAccept,
},
}, },
UserData: []byte(allowNetbirdInputRuleID), UserData: []byte(allowNetbirdInputRuleID),
} }

View File

@@ -1,9 +1,11 @@
package nftables package nftables
import ( import (
"bytes"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os/exec"
"testing" "testing"
"time" "time"
@@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}) })
} }
} }
func runIptablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("iptables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
require.NoError(t, err, "iptables-save failed to run")
return stdout.String(), stderr.String()
}
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
// Check for any incompatibility warnings
require.NotContains(t,
stderr,
"incompatible",
"iptables-save produced compatibility warning. Full stderr: %s",
stderr,
)
// Verify standard tables are present
expectedTables := []string{
"*filter",
"*nat",
"*mangle",
}
for _, table := range expectedTables {
require.Contains(t,
stdout,
table,
"iptables-save output missing expected table: %s\nFull stdout: %s",
table,
stdout,
)
}
}
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Reset(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,
}
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}

View File

@@ -1 +0,0 @@
package nftables

View File

@@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// SetLegacyManagement doesn't need to be implemented for this manager // SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return nil
} }
return m.nativeFirewall.SetLegacyManagement(isLegacy) return m.nativeFirewall.SetLegacyManagement(isLegacy)
} }

View File

@@ -0,0 +1,12 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType var networks []ice.NetworkType
switch { switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil: case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default: default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
} }

View File

@@ -218,3 +218,31 @@ func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
RxBytes: peer.ReceiveBytes, RxBytes: peer.ReceiveBytes,
}, nil }, nil
} }
func (c *KernelConfigurer) GetAllStat() (map[string]WGStats, error) {
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
stats := make(map[string]WGStats)
for _, peer := range wgDevice.Peers {
stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}
}
return stats, nil
}

View File

@@ -263,6 +263,52 @@ func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
}, nil }, nil
} }
func (t *WGUSPConfigurer) GetAllStat() (map[string]WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("ipc get: %w", err)
}
stats, err := parsePeerInfo(ipc, []string{
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return nil, fmt.Errorf("find peer info: %w", err)
}
wgStats := make(map[string]WGStats)
for k, v := range stats {
sec, err := strconv.ParseInt(v["last_handshake_time_sec"], 10, 64)
if err != nil {
return nil, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(v["last_handshake_time_nsec"], 10, 64)
if err != nil {
return nil, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(v["tx_bytes"], 10, 64)
if err != nil {
return nil, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(v["rx_bytes"], 10, 64)
if err != nil {
return nil, fmt.Errorf("parse rx_bytes: %w", err)
}
wgStats[k] = WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}
}
return wgStats, nil
}
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) { func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
@@ -310,6 +356,44 @@ func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (m
return configFound, nil return configFound, nil
} }
func parsePeerInfo(ipcInput string, searchConfigKeys []string) (map[string]map[string]string, error) {
lines := strings.Split(ipcInput, "\n")
allPeers := map[string]map[string]string{}
var currentPeerKey string
for _, line := range lines {
line = strings.TrimSpace(line)
// Detect new peer section by public key
if strings.HasPrefix(line, "public_key=") {
hexKey := strings.TrimPrefix(line, "public_key=")
keyBytes, _ := hex.DecodeString(hexKey)
wgKey, _ := wgtypes.NewKey(keyBytes)
currentPeerKey = wgKey.String()
if _, exists := allPeers[currentPeerKey]; !exists {
allPeers[currentPeerKey] = map[string]string{}
}
continue
}
// Parse configuration keys for the current peer
if currentPeerKey != "" {
for _, key := range searchConfigKeys {
if strings.HasPrefix(line, key+"=") {
v := strings.SplitN(line, "=", 2)
if len(v) == 2 {
allPeers[currentPeerKey][v[0]] = v[1]
}
}
}
}
}
return allPeers, nil
}
func toWgUserspaceString(wgCfg wgtypes.Config) string { func toWgUserspaceString(wgCfg wgtypes.Config) string {
var sb strings.Builder var sb strings.Builder
if wgCfg.PrivateKey != nil { if wgCfg.PrivateKey != nil {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"testing" "testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -34,6 +35,19 @@ errno=0
` `
func Test_parsePeerInto(t *testing.T) {
r, err := parsePeerInfo(ipcFixture, []string{
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
t.Errorf("parsePeerInfo() error = %v", err)
}
log.Infof("r: %v", r)
}
func Test_findPeerInfo(t *testing.T) { func Test_findPeerInfo(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -17,4 +17,5 @@ type WGConfigurer interface {
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error
Close() Close()
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
GetAllStat() (map[string]configurer.WGStats, error)
} }

View File

@@ -27,14 +27,14 @@ import (
type status int type status int
const ( const (
defaultModuleDir = "/lib/modules" unknown status = 1
unknown status = iota unloaded status = 2
unloaded unloading status = 3
unloading loading status = 4
loading live status = 5
live inuse status = 6
inuse defaultModuleDir = "/lib/modules"
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED" envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
) )
type module struct { type module struct {

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/connprofile"
) )
const ( const (
@@ -114,7 +115,13 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) err := w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
if err != nil {
return err
}
connprofile.Profiler.WireGuardConfigured(peerKey)
return nil
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
@@ -208,6 +215,10 @@ func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey) return w.configurer.GetStats(peerKey)
} }
func (w *WGIface) GetAllStat() (map[string]configurer.WGStats, error) {
return w.configurer.GetAllStat()
}
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime) timeout := time.NewTimer(maxWaitTime)

View File

@@ -33,4 +33,5 @@ type IWGIface interface {
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
GetAllStat() (map[string]configurer.WGStats, error)
} }

View File

@@ -46,6 +46,7 @@ type ConfigInput struct {
ManagementURL string ManagementURL string
AdminURL string AdminURL string
ConfigPath string ConfigPath string
StateFilePath string
PreSharedKey *string PreSharedKey *string
ServerSSHAllowed *bool ServerSSHAllowed *bool
NATExternalIPs []string NATExternalIPs []string
@@ -105,10 +106,10 @@ type Config struct {
// DNSRouteInterval is the interval in which the DNS routes are updated // DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration DNSRouteInterval time.Duration
//Path to a certificate used for mTLS authentication // Path to a certificate used for mTLS authentication
ClientCertPath string ClientCertPath string
//Path to corresponding private key of ClientCertPath // Path to corresponding private key of ClientCertPath
ClientCertKeyPath string ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"` ClientCertKeyPair *tls.Certificate `json:"-"`
@@ -116,7 +117,7 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) { if fileExists(configPath) {
err := util.EnforcePermission(configPath) err := util.EnforcePermission(configPath)
if err != nil { if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err) log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -149,7 +150,7 @@ func ReadConfig(configPath string) (*Config, error) {
// UpdateConfig update existing configuration according to input configuration and return with the configuration // UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) { func UpdateConfig(input ConfigInput) (*Config, error) {
if !configFileIsExists(input.ConfigPath) { if !fileExists(input.ConfigPath) {
return nil, status.Errorf(codes.NotFound, "config file doesn't exist") return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
} }
@@ -158,13 +159,13 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one // UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !configFileIsExists(input.ConfigPath) { if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath) log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input) cfg, err := createNewConfig(input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
@@ -185,7 +186,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path // WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error { func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(path, config) return util.WriteJson(context.Background(), path, config)
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -215,7 +216,7 @@ func update(input ConfigInput) (*Config, error) {
} }
if updated { if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil { if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err return nil, err
} }
} }
@@ -472,11 +473,19 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
return false return false
} }
func configFileIsExists(path string) bool { func fileExists(path string) bool {
_, err := os.Stat(path) _, err := os.Stat(path)
return !os.IsNotExist(err) return !os.IsNotExist(err)
} }
func createFile(path string) error {
file, err := os.Create(path)
if err != nil {
return err
}
return file.Close()
}
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
// The check is performed only for the NetBird's managed version. // The check is performed only for the NetBird's managed version.

View File

@@ -40,6 +40,8 @@ type ConnectClient struct {
statusRecorder *peer.Status statusRecorder *peer.Status
engine *Engine engine *Engine
engineMutex sync.Mutex engineMutex sync.Mutex
persistNetworkMap bool
} }
func NewConnectClient( func NewConnectClient(
@@ -89,6 +91,7 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32, fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager, dnsManager dns.IosDnsManager,
stateFilePath string,
) error { ) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5) debug.SetGCPercent(5)
@@ -97,6 +100,7 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor, FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener, NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager, DnsManager: dnsManager,
StateFilePath: stateFilePath,
} }
return c.run(mobileDependency, nil, nil) return c.run(mobileDependency, nil, nil)
} }
@@ -157,7 +161,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
engineCtx, cancel := context.WithCancel(c.ctx) engineCtx, cancel := context.WithCancel(c.ctx)
defer func() { defer func() {
c.statusRecorder.MarkManagementDisconnected(state.err) _, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
c.statusRecorder.CleanLocalPeerState() c.statusRecorder.CleanLocalPeerState()
cancel() cancel()
}() }()
@@ -231,6 +236,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
relayURLs, token := parseRelayInfo(loginResp) relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 { if len(relayURLs) > 0 {
if token != nil { if token != nil {
if err := relayManager.UpdateToken(token); err != nil { if err := relayManager.UpdateToken(token); err != nil {
@@ -241,9 +247,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", ")) log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
if err = relayManager.Serve(); err != nil { if err = relayManager.Serve(); err != nil {
log.Error(err) log.Error(err)
return wrapErr(err)
} }
c.statusRecorder.SetRelayMgr(relayManager)
} }
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
@@ -258,7 +262,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.engineMutex.Lock() c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil { if err := c.engine.Start(); err != nil {
@@ -336,6 +340,19 @@ func (c *ConnectClient) Engine() *Engine {
return e return e
} }
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
return StatusIdle
}
status, err := CtxGetState(c.ctx).Status()
if err != nil {
return StatusIdle
}
return status
}
func (c *ConnectClient) Stop() error { func (c *ConnectClient) Stop() error {
if c == nil { if c == nil {
return nil return nil
@@ -362,6 +379,22 @@ func (c *ConnectClient) isContextCancelled() bool {
} }
} }
// SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored
// network map will be cleared. This functionality is primarily used for debugging
// and should not be enabled during normal operation.
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
c.engineMutex.Lock()
c.persistNetworkMap = enabled
c.engineMutex.Unlock()
engine := c.Engine()
if engine != nil {
engine.SetNetworkMapPersistence(enabled)
}
}
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false nm := false

View File

@@ -7,7 +7,6 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
@@ -323,12 +322,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err) log.Error(err)
} }
// persist dns state right away go func() {
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) // persist dns state right away
defer cancel() if err := s.stateManager.PersistState(s.ctx); err != nil {
if err := s.stateManager.PersistState(ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err)
log.Errorf("Failed to persist dns state: %v", err) }
} }()
if s.searchDomainNotifier != nil { if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
@@ -533,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
// persist dns state right away go func() {
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) if err := s.stateManager.PersistState(s.ctx); err != nil {
defer cancel() l.Errorf("Failed to persist dns state: %v", err)
if err := s.stateManager.PersistState(ctx); err != nil { }
l.Errorf("Failed to persist dns state: %v", err) }()
}
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone() s.addHostRootZone()

View File

@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53, Port: 53,
}, },
}, },
Domains: []string{"customdomain.com"}, Domains: []string{"google.com"},
Primary: false, Primary: false,
}, },
}, },
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
if ips[0] != zoneRecords[0].RData { if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err) t.Fatalf("invalid zone record: %v", err)
} }
_, err = resolver.LookupHost(context.Background(), "customdomain.com") _, err = resolver.LookupHost(context.Background(), "google.com")
if err != nil { if err != nil {
t.Errorf("failed to resolve: %s", err) t.Errorf("failed to resolve: %s", err)
} }

View File

@@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"slices" "slices"
"sort"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -20,6 +21,7 @@ import (
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
@@ -37,7 +39,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/connprofile"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@@ -61,6 +64,7 @@ import (
const ( const (
PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
) )
var ErrResetConnection = fmt.Errorf("reset connection") var ErrResetConnection = fmt.Errorf("reset connection")
@@ -171,7 +175,12 @@ type Engine struct {
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager stateManager *statemanager.Manager
srWatcher *guard.SRWatcher srWatcher *guard.SRWatcher
// Network map persistence
persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap
connSemaphore *semaphoregroup.SemaphoreGroup
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -237,6 +246,18 @@ func NewEngineWithProbes(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
probes: probes, probes: probes,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
}
if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
engine.stateManager = statemanager.New(mobileDep.StateFilePath)
} }
if path := statemanager.GetDefaultStatePath(); path != "" { if path := statemanager.GetDefaultStatePath(); path != "" {
engine.stateManager = statemanager.New(path) engine.stateManager = statemanager.New(path)
@@ -271,6 +292,10 @@ func (e *Engine) Stop() error {
e.srWatcher.Close() e.srWatcher.Close()
} }
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
err := e.removeAllPeers() err := e.removeAllPeers()
if err != nil { if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
@@ -297,7 +322,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil { if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err) return fmt.Errorf("failed to stop state manager: %w", err)
} }
if err := e.stateManager.PersistState(ctx); err != nil { if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
@@ -349,8 +374,17 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) e.routeManager = routemanager.NewManager(
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) e.ctx,
e.config.WgPrivateKey.PublicKey().String(),
e.config.DNSRouteInterval,
e.wgInterface,
e.statusRecorder,
e.relayManager,
initialRoutes,
e.stateManager,
)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else { } else {
@@ -387,6 +421,8 @@ func (e *Engine) Start() error {
return fmt.Errorf("up wg interface: %w", err) return fmt.Errorf("up wg interface: %w", err)
} }
connprofile.Profiler.WGInterfaceUP(e.wgInterface)
if e.firewall != nil { if e.firewall != nil {
e.acl = acl.NewDefaultManager(e.firewall) e.acl = acl.NewDefaultManager(e.firewall)
} }
@@ -538,6 +574,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
relayMsg := wCfg.GetRelay() relayMsg := wCfg.GetRelay()
if relayMsg != nil { if relayMsg != nil {
// when we receive token we expect valid address list too
c := &auth.Token{ c := &auth.Token{
Payload: relayMsg.GetTokenPayload(), Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(), Signature: relayMsg.GetTokenSignature(),
@@ -546,9 +583,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
log.Errorf("failed to update relay token: %v", err) log.Errorf("failed to update relay token: %v", err)
return fmt.Errorf("update relay token: %w", err) return fmt.Errorf("update relay token: %w", err)
} }
e.relayManager.UpdateServerURLs(relayMsg.Urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.
_ = e.relayManager.Serve()
} else {
e.relayManager.UpdateServerURLs(nil)
} }
// todo update relay address in the relay manager
// todo update signal // todo update signal
} }
@@ -556,13 +600,22 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err return err
} }
if update.GetNetworkMap() != nil { nm := update.GetNetworkMap()
// only apply new changes and ignore old ones if nm == nil {
err := e.updateNetworkMap(update.GetNetworkMap()) return nil
if err != nil {
return err
}
} }
// Store network map if persistence is enabled
if e.persistNetworkMap {
e.latestNetworkMap = nm
log.Debugf("network map persisted with serial %d", nm.GetSerial())
}
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
return err
}
return nil return nil
} }
@@ -641,6 +694,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
}
if e.wgInterface.Address().String() != conf.Address { if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String() oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
@@ -732,7 +789,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
} }
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't // intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil { if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig()) err := e.updateConfig(networkMap.GetPeerConfig())
@@ -767,6 +823,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.clientRoutesMu.Unlock() e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
connprofile.Profiler.NetworkMapUpdate(networkMap.GetRemotePeers())
e.updateOfflinePeers(networkMap.GetOfflinePeers()) e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -1001,7 +1058,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
}, },
} }
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1051,6 +1108,7 @@ func (e *Engine) receiveSignalEvents() {
RosenpassAddr: rosenpassAddr, RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
}) })
connprofile.Profiler.OfferAnswerReceived(msg.Key)
case sProto.Body_ANSWER: case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg) remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil { if err != nil {
@@ -1074,6 +1132,7 @@ func (e *Engine) receiveSignalEvents() {
RosenpassAddr: rosenpassAddr, RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
}) })
connprofile.Profiler.OfferAnswerReceived(msg.Key)
case sProto.Body_CANDIDATE: case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload) candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
if err != nil { if err != nil {
@@ -1479,8 +1538,59 @@ func (e *Engine) stopDNSServer() {
e.statusRecorder.UpdateDNSStates(nsGroupStates) e.statusRecorder.UpdateDNSStates(nsGroupStates)
} }
// SetNetworkMapPersistence enables or disables network map persistence
func (e *Engine) SetNetworkMapPersistence(enabled bool) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if enabled == e.persistNetworkMap {
return
}
e.persistNetworkMap = enabled
log.Debugf("Network map persistence is set to %t", enabled)
if !enabled {
e.latestNetworkMap = nil
}
}
// GetLatestNetworkMap returns the stored network map if persistence is enabled
func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if !e.persistNetworkMap {
return nil, errors.New("network map persistence is disabled")
}
if e.latestNetworkMap == nil {
//nolint:nilnil
return nil, nil
}
// Create a deep copy to avoid external modifications
nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap)
if !ok {
return nil, fmt.Errorf("failed to clone network map")
}
return nm, nil
}
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
sort.Slice(check.Files, func(i, j int) bool {
return check.Files[i] < check.Files[j]
})
}
for _, oCheck := range oChecks {
sort.Slice(oCheck.Files, func(i, j int) bool {
return oCheck.Files[i] < oCheck.Files[j]
})
}
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files) return slices.Equal(checks.Files, oChecks.Files)
}) })

View File

@@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
nil) nil)
wgIface := &iface.MockWGIface{ wgIface := &iface.MockWGIface{
NameFunc: func() string { return "utun102" },
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
_, _, err = engine.routeManager.Init()
require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }
@@ -1006,6 +1009,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
} }
} }
func Test_CheckFilesEqual(t *testing.T) {
testCases := []struct {
name string
inputChecks1 []*mgmtProto.Checks
inputChecks2 []*mgmtProto.Checks
expectedBool bool
}{
{
name: "Equal Files In Equal Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
expectedBool: true,
},
{
name: "Equal Files In Reverse Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile2",
"testfile1",
},
},
},
expectedBool: true,
},
{
name: "Unequal Files Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile3",
},
},
},
expectedBool: false,
},
{
name: "Compared With Empty Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{},
},
},
expectedBool: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {

View File

@@ -19,4 +19,5 @@ type MobileDependency struct {
// iOS only // iOS only
DnsManager dns.IosDnsManager DnsManager dns.IosDnsManager
FileDescriptor int32 FileDescriptor int32
StateFilePath string
} }

View File

@@ -23,6 +23,7 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
type ConnPriority int type ConnPriority int
@@ -83,7 +84,6 @@ type Conn struct {
signaler *Signaler signaler *Signaler
relayManager *relayClient.Manager relayManager *relayClient.Manager
allowedIP net.IP allowedIP net.IP
allowedNet string
handshaker *Handshaker handshaker *Handshaker
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
@@ -105,13 +105,14 @@ type Conn struct {
wgProxyICE wgproxy.Proxy wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil { if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err) log.Errorf("failed to parse allowedIPS: %v", err)
return nil, err return nil, err
@@ -129,9 +130,9 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler, signaler: signaler,
relayManager: relayManager, relayManager: relayManager,
allowedIP: allowedIP, allowedIP: allowedIP,
allowedNet: allowedNet.String(),
statusRelay: NewAtomicConnStatus(), statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
} }
rFns := WorkerRelayCallbacks{ rFns := WorkerRelayCallbacks{
@@ -171,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used. // be used.
func (conn *Conn) Open() { func (conn *Conn) Open() {
conn.semaphore.Add(conn.ctx)
conn.log.Debugf("open connection to peer") conn.log.Debugf("open connection to peer")
conn.mu.Lock() conn.mu.Lock()
@@ -193,6 +195,7 @@ func (conn *Conn) Open() {
} }
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
defer conn.semaphore.Done(conn.ctx)
conn.waitInitialRandomSleepTime(ctx) conn.waitInitialRandomSleepTime(ctx)
err := conn.handshaker.sendOffer() err := conn.handshaker.sendOffer()
@@ -594,14 +597,13 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
} }
if conn.onConnected != nil { if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
} }
} }
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) { func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
minWait := 100 maxWait := 300
maxWait := 800 duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
timeout := time.NewTimer(duration) timeout := time.NewTimer(duration)
defer timeout.Stop() defer timeout.Stop()

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
var connConf = ConnConfig{ var connConf = ConnConfig{
@@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil { if err != nil {
return return
} }
@@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil { if err != nil {
return return
} }
@@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil { if err != nil {
return return
} }
@@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil { if err != nil {
return return
} }

View File

@@ -4,6 +4,7 @@ import (
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/connprofile"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
) )
@@ -66,5 +67,6 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
return err return err
} }
connprofile.Profiler.OfferSent(remoteKey)
return nil return nil
} }

View File

@@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
// extend the list of stun, turn servers with relay address // extend the list of stun, turn servers with relay address
relayStates := slices.Clone(d.relayStates) relayStates := slices.Clone(d.relayStates)
var relayState relay.ProbeResult
// if the server connection is not established then we will use the general address // if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address // in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress() instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil { if err != nil {
// TODO add their status // TODO add their status
if errors.Is(err, relayClient.ErrRelayClientNotConnected) { for _, r := range d.relayMgr.ServerURLs() {
for _, r := range d.relayMgr.ServerURLs() { relayStates = append(relayStates, relay.ProbeResult{
relayStates = append(relayStates, relay.ProbeResult{ URI: r,
URI: r, Err: err,
}) })
}
return relayStates
} }
relayState.Err = err return relayStates
} }
relayState.URI = instanceAddr relayState := relay.ProbeResult{
URI: instanceAddr,
}
return append(relayStates, relayState) return append(relayStates, relayState)
} }

View File

@@ -46,8 +46,6 @@ type WorkerICE struct {
hasRelayOnLocally bool hasRelayOnLocally bool
conn WorkerICECallbacks conn WorkerICECallbacks
selectedPriority ConnPriority
agent *ice.Agent agent *ice.Agent
muxAgent sync.Mutex muxAgent sync.Mutex
@@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = icemaker.CandidateTypesP2P() preferredCandidateTypes = icemaker.CandidateTypesP2P()
} else { } else {
w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = icemaker.CandidateTypes() preferredCandidateTypes = icemaker.CandidateTypes()
} }
@@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
} }
w.log.Debugf("on ICE conn read to use ready") w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci) go w.conn.OnConnReady(selectedPriority(pair), ci)
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -268,7 +264,13 @@ func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration // wait local endpoint configuration
time.Sleep(time.Second) time.Sleep(time.Second)
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort)) addrString := pair.Remote.Address()
parsed, err := netip.ParseAddr(addrString)
if (err == nil) && (parsed.Is6()) {
addrString = fmt.Sprintf("[%s]", addrString)
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
}
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
if err != nil { if err != nil {
w.log.Warnf("got an error while resolving the udp address, err: %s", err) w.log.Warnf("got an error while resolving the udp address, err: %s", err)
return return
@@ -394,3 +396,11 @@ func isRelayed(pair *ice.CandidatePair) bool {
} }
return false return false
} }
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
if isRelayed(pair) {
return connPriorityICETurn
} else {
return connPriorityICEP2P
}
}

View File

@@ -1,6 +1,7 @@
package routemanager package routemanager
import ( import (
"fmt"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute: "route1", currentRoute: "route1",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{ {
name: "current route with bad score should be changed to route with better score", name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
} }
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{ currentRoute := &route.Route{

View File

@@ -32,7 +32,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
@@ -59,6 +59,7 @@ type DefaultManager struct {
routeRefCounter *refcounter.RouteRefCounter routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
stateManager *statemanager.Manager
} }
func NewManager( func NewManager(
@@ -69,6 +70,7 @@ func NewManager(
statusRecorder *peer.Status, statusRecorder *peer.Status,
relayMgr *relayClient.Manager, relayMgr *relayClient.Manager,
initialRoutes []*route.Route, initialRoutes []*route.Route,
stateManager *statemanager.Manager,
) *DefaultManager { ) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
@@ -80,12 +82,12 @@ func NewManager(
dnsRouteInterval: dnsRouteInterval, dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork), clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
relayMgr: relayMgr, relayMgr: relayMgr,
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps, sysOps: sysOps,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
notifier: notifier, notifier: notifier,
stateManager: stateManager,
} }
dm.routeRefCounter = refcounter.New( dm.routeRefCounter = refcounter.New(
@@ -121,7 +123,9 @@ func NewManager(
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
m.routeSelector = m.initSelector()
if nbnet.CustomRoutingDisabled() { if nbnet.CustomRoutingDisabled() {
return nil, nil, nil return nil, nil, nil
} }
@@ -137,14 +141,36 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err) return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
log.Info("Routing setup complete") log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil return beforePeerHook, afterPeerHook, nil
} }
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
var state *SelectorState
m.stateManager.RegisterState(state)
// restore selector state if it exists
if err := m.stateManager.LoadState(state); err != nil {
log.Warnf("failed to load state: %v", err)
return routeselector.NewRouteSelector()
}
if state := m.stateManager.GetState(state); state != nil {
if selector, ok := state.(*SelectorState); ok {
return (*routeselector.RouteSelector)(selector)
}
log.Warnf("failed to convert state with type %T to SelectorState", state)
}
return routeselector.NewRouteSelector()
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
@@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
} }
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
log.Errorf("failed to update state: %v", err)
}
} }
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list // stopObsoleteClients stops the client network watcher for the networks that are not in the new list

View File

@@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm") statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
_, _, err = routeManager.Init(nil) _, _, err = routeManager.Init()
require.NoError(t, err, "should init route manager") require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil) defer routeManager.Stop(nil)

View File

@@ -21,7 +21,7 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil return nil, nil, nil
} }

View File

@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct { type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys // refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O] refCountMap map[Key]Ref[O]
refCountMu sync.Mutex mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal // idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O] add AddFunc[Key, I, O]
remove RemoveFunc[Key, O] remove RemoveFunc[Key, O]
} }
@@ -72,13 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
} }
// LoadData loads the data from the existing counter // LoadData loads the data from the existing counter
// The passed counter should not be used any longer after calling this function.
func (rm *Counter[Key, I, O]) LoadData( func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O], existingCounter *Counter[Key, I, O],
) { ) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock() existingCounter.mu.Lock()
defer rm.idMu.Unlock() defer existingCounter.mu.Unlock()
rm.refCountMap = existingCounter.refCountMap rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap rm.idMap = existingCounter.idMap
@@ -87,8 +87,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key. // Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false. // If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
ref, ok := rm.refCountMap[key] ref, ok := rm.refCountMap[key]
return ref, ok return ref, ok
@@ -97,9 +97,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key. // Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called. // If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
return rm.increment(key, in)
}
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key] ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
@@ -126,10 +130,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID. // IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called. // If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock() rm.mu.Lock()
defer rm.idMu.Unlock() defer rm.mu.Unlock()
ref, err := rm.Increment(key, in) ref, err := rm.increment(key, in)
if err != nil { if err != nil {
return ref, fmt.Errorf("with ID: %w", err) return ref, fmt.Errorf("with ID: %w", err)
} }
@@ -141,9 +145,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key. // Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called. // If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
return rm.decrement(key)
}
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key] ref, ok := rm.refCountMap[key]
if !ok { if !ok {
logCallerF("No reference found for key %v", key) logCallerF("No reference found for key %v", key)
@@ -168,12 +175,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID. // DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called. // If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock() rm.mu.Lock()
defer rm.idMu.Unlock() defer rm.mu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for _, key := range rm.idMap[id] { for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil { if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
} }
@@ -184,10 +191,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
// Flush removes all references and calls RemoveFunc for each key. // Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error { func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for key := range rm.refCountMap { for key := range rm.refCountMap {
@@ -206,10 +211,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
// Clear removes all references without calling RemoveFunc. // Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() { func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
clear(rm.refCountMap) clear(rm.refCountMap)
clear(rm.idMap) clear(rm.idMap)
@@ -217,10 +220,8 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter. // MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
return json.Marshal(struct { return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"` RefCountMap map[Key]Ref[O] `json:"refCountMap"`
@@ -233,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements the json.Unmarshaler interface for Counter. // UnmarshalJSON implements the json.Unmarshaler interface for Counter.
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.mu.Lock()
defer rm.mu.Unlock()
var temp struct { var temp struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"` RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"` IDMap map[string][]Key `json:"idMap"`
@@ -243,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.refCountMap = temp.RefCountMap rm.refCountMap = temp.RefCountMap
rm.idMap = temp.IDMap rm.idMap = temp.IDMap
if temp.RefCountMap == nil {
temp.RefCountMap = map[Key]Ref[O]{}
}
if temp.IDMap == nil {
temp.IDMap = map[string][]Key{}
}
return nil return nil
} }

View File

@@ -0,0 +1,19 @@
package routemanager
import (
"github.com/netbirdio/netbird/client/internal/routeselector"
)
type SelectorState routeselector.RouteSelector
func (s *SelectorState) Name() string {
return "routeselector_state"
}
func (s *SelectorState) MarshalJSON() ([]byte, error) {
return (*routeselector.RouteSelector)(s).MarshalJSON()
}
func (s *SelectorState) UnmarshalJSON(data []byte) error {
return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
}

View File

@@ -2,31 +2,28 @@ package systemops
import ( import (
"net/netip" "net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
) )
type ShutdownState struct { type ShutdownState ExclusionCounter
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
return "route_state" return "route_state"
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.Counter == nil {
return nil
}
sysops := NewSysOps(nil, nil) sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter) sysops.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush() return sysops.refCounter.Flush()
} }
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}

View File

@@ -57,30 +57,19 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, refcounter.ErrIgnore return nexthop, refcounter.ErrIgnore
} }
r.updateState(stateManager)
return nexthop, err return nexthop, err
}, },
func(prefix netip.Prefix, nexthop Nexthop) error { r.removeFromRouteTable,
// remove from state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)
return r.removeFromRouteTable(prefix, nexthop)
},
) )
r.refCounter = refCounter r.refCounter = refCounter
return r.setupHooks(initAddresses) return r.setupHooks(initAddresses, stateManager)
} }
// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) { func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager) if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
state.Counter = r.refCounter
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
} }
@@ -336,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@@ -347,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("adding route reference: %v", err) return fmt.Errorf("adding route reference: %v", err)
} }
r.updateState(stateManager)
return nil return nil
} }
afterHook := func(connID nbnet.ConnectionID) error { afterHook := func(connID nbnet.ConnectionID) error {
@@ -354,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("remove route reference: %w", err) return fmt.Errorf("remove route reference: %w", err)
} }
r.updateState(stateManager)
return nil return nil
} }
@@ -532,14 +525,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes // Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix return isVpn, longestPrefix
} }
func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = &ShutdownState{}
}
return shutdownState
}

View File

@@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup // isLegacy determines whether to use the legacy routing setup
func isLegacy() bool { func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
} }
// setIsLegacy sets the legacy routing setup // setIsLegacy sets the legacy routing setup
@@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() { defer func() {
if err != nil { if err != nil {
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
@@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
} }
} }
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
return nil, nil, nil return nil, nil, nil
} }

View File

@@ -230,10 +230,13 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
if idx != 0 { if idx != 0 {
intf, err := net.InterfaceByIndex(idx) intf, err := net.InterfaceByIndex(idx)
if err != nil { if err != nil {
return update, fmt.Errorf("get interface name: %w", err) log.Warnf("failed to get interface name for index %d: %v", idx, err)
update.Interface = &net.Interface{
Index: idx,
}
} else {
update.Interface = intf
} }
update.Interface = intf
} }
log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)

View File

@@ -1,8 +1,10 @@
package routeselector package routeselector
import ( import (
"encoding/json"
"fmt" "fmt"
"slices" "slices"
"sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -12,6 +14,7 @@ import (
) )
type RouteSelector struct { type RouteSelector struct {
mu sync.RWMutex
selectedRoutes map[route.NetID]struct{} selectedRoutes map[route.NetID]struct{}
selectAll bool selectAll bool
} }
@@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector {
// SelectRoutes updates the selected routes based on the provided route IDs. // SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error { func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()
if !appendRoute { if !appendRoute {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
@@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
// SelectAllRoutes sets the selector to select all routes. // SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() { func (rs *RouteSelector) SelectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()
rs.selectAll = true rs.selectAll = true
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
@@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() {
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode. // If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.selectAll { if rs.selectAll {
rs.selectAll = false rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
@@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
// DeselectAllRoutes deselects all routes, effectively disabling route selection. // DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() { func (rs *RouteSelector) DeselectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()
rs.selectAll = false rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
// IsSelected checks if a specific route is selected. // IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.selectAll { if rs.selectAll {
return true return true
} }
@@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
// FilterSelected removes unselected routes from the provided map. // FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.selectAll { if rs.selectAll {
return maps.Clone(routes) return maps.Clone(routes)
} }
@@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
} }
return filtered return filtered
} }
// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
defer rs.mu.RUnlock()
return json.Marshal(struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}{
SelectAll: rs.selectAll,
SelectedRoutes: rs.selectedRoutes,
})
}
// UnmarshalJSON implements the json.Unmarshaler interface
// If the JSON is empty or null, it will initialize like a NewRouteSelector.
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
rs.mu.Lock()
defer rs.mu.Unlock()
// Check for null or empty JSON
if len(data) == 0 || string(data) == "null" {
rs.selectedRoutes = map[route.NetID]struct{}{}
rs.selectAll = true
return nil
}
var temp struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
rs.selectedRoutes = temp.SelectedRoutes
rs.selectAll = temp.SelectAll
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
return nil
}

View File

@@ -273,3 +273,88 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
"route2|192.168.0.0/16": {}, "route2|192.168.0.0/16": {},
}, filtered) }, filtered)
} }
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialRoutes := []route.NetID{"route1", "route2", "route3"}
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}
tests := []struct {
name string
initialState func(rs *routeselector.RouteSelector) error // Setup initial state
wantNewSelected []route.NetID // Expected selected routes after new routes appear
}{
{
name: "New routes with initial selectAll state",
initialState: func(rs *routeselector.RouteSelector) error {
rs.SelectAllRoutes()
return nil
},
// When selectAll is true, all routes including new ones should be selected
wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"},
},
{
name: "New routes after specific selection",
initialState: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes)
},
// When specific routes were selected, new routes should remain unselected
wantNewSelected: []route.NetID{"route1", "route2"},
},
{
name: "New routes after deselect all",
initialState: func(rs *routeselector.RouteSelector) error {
rs.DeselectAllRoutes()
return nil
},
// After deselect all, new routes should remain unselected
wantNewSelected: []route.NetID{},
},
{
name: "New routes after deselecting specific routes",
initialState: func(rs *routeselector.RouteSelector) error {
rs.SelectAllRoutes()
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
},
// After deselecting specific routes, new routes should remain unselected
wantNewSelected: []route.NetID{"route2", "route3"},
},
{
name: "New routes after selecting with append",
initialState: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes)
},
// When routes were appended, new routes should remain unselected
wantNewSelected: []route.NetID{"route1"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
// Setup initial state
err := tt.initialState(rs)
require.NoError(t, err)
// Verify selection state with new routes
for _, id := range newRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantNewSelected, id),
"Route %s selection state incorrect", id)
}
// Additional verification using FilterSelected
routes := route.HAMap{
"route1|10.0.0.0/8": {},
"route2|192.168.0.0/16": {},
"route3|172.16.0.0/12": {},
"route4|10.10.0.0/16": {},
"route5|192.168.1.0/24": {},
}
filtered := rs.FilterSelected(routes)
expectedLen := len(tt.wantNewSelected)
assert.Equal(t, expectedLen, len(filtered),
"FilterSelected returned wrong number of routes, got %d want %d", len(filtered), expectedLen)
})
}
}

View File

@@ -16,14 +16,39 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/util"
)
const (
errStateNotRegistered = "state %s not registered"
errLoadStateFile = "load state file: %w"
) )
// State interface defines the methods that all state types must implement // State interface defines the methods that all state types must implement
type State interface { type State interface {
Name() string Name() string
}
// CleanableState interface extends State with cleanup capability
type CleanableState interface {
State
Cleanup() error Cleanup() error
} }
// RawState wraps raw JSON data for unregistered states
type RawState struct {
data json.RawMessage
}
func (r *RawState) Name() string {
return "" // This is a placeholder implementation
}
// MarshalJSON implements json.Marshaler to preserve the original JSON
func (r *RawState) MarshalJSON() ([]byte, error) {
return r.data, nil
}
// Manager handles the persistence and management of various states // Manager handles the persistence and management of various states
type Manager struct { type Manager struct {
mu sync.Mutex mu sync.Mutex
@@ -73,15 +98,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if m.cancel != nil { if m.cancel == nil {
m.cancel() return nil
}
m.cancel()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-m.done: case <-m.done:
return nil
}
} }
return nil return nil
@@ -139,7 +164,7 @@ func (m *Manager) setState(name string, state State) error {
defer m.mu.Unlock() defer m.mu.Unlock()
if _, exists := m.states[name]; !exists { if _, exists := m.states[name]; !exists {
return fmt.Errorf("state %s not registered", name) return fmt.Errorf(errStateNotRegistered, name)
} }
m.states[name] = state m.states[name] = state
@@ -148,6 +173,63 @@ func (m *Manager) setState(name string, state State) error {
return nil return nil
} }
// DeleteStateByName handles deletion of states without cleanup.
// It doesn't require the state to be registered.
func (m *Manager) DeleteStateByName(stateName string) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil
}
if _, exists := rawStates[stateName]; !exists {
return fmt.Errorf("state %s not found", stateName)
}
// Mark state as deleted by setting it to nil and marking it dirty
m.states[stateName] = nil
m.dirty[stateName] = struct{}{}
return nil
}
// DeleteAllStates removes all states.
func (m *Manager) DeleteAllStates() (int, error) {
if m == nil {
return 0, nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return 0, fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return 0, nil
}
count := len(rawStates)
// Mark all states as deleted and dirty
for name := range rawStates {
m.states[name] = nil
m.dirty[name] = struct{}{}
}
return count, nil
}
func (m *Manager) periodicStateSave(ctx context.Context) { func (m *Manager) periodicStateSave(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second) ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -178,25 +260,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil return nil
} }
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
done := make(chan error, 1) done := make(chan error, 1)
start := time.Now()
go func() { go func() {
data, err := json.MarshalIndent(m.states, "", " ") done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}
// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}
done <- nil
}() }()
select { select {
@@ -208,63 +283,175 @@ func (m *Manager) PersistState(ctx context.Context) error {
} }
} }
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
clear(m.dirty) clear(m.dirty)
return nil return nil
} }
// loadState loads the existing state from the state file // loadStateFile reads and unmarshals the state file into a map of raw JSON messages
func (m *Manager) loadState() error { func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) {
data, err := os.ReadFile(m.filePath) data, err := os.ReadFile(m.filePath)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist") log.Debug("state file does not exist")
return nil return nil, nil // nolint:nilnil
} }
return fmt.Errorf("read state file: %w", err) return nil, fmt.Errorf("read state file: %w", err)
} }
var rawStates map[string]json.RawMessage var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil { if err := json.Unmarshal(data, &rawStates); err != nil {
log.Warn("State file appears to be corrupted, attempting to delete it") if deleteCorrupt {
if err := os.Remove(m.filePath); err != nil { log.Warn("State file appears to be corrupted, attempting to delete it", err)
log.Errorf("Failed to delete corrupted state file: %v", err) if err := os.Remove(m.filePath); err != nil {
} else { log.Errorf("Failed to delete corrupted state file: %v", err)
log.Info("State file deleted") } else {
log.Info("State file deleted")
}
} }
return fmt.Errorf("unmarshal states: %w", err) return nil, fmt.Errorf("unmarshal states: %w", err)
} }
var merr *multierror.Error return rawStates, nil
}
for name, rawState := range rawStates { // loadSingleRawState unmarshals a raw state into a concrete state object
stateType, ok := m.stateTypes[name] func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
if !ok { stateType, ok := m.stateTypes[name]
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) if !ok {
continue return nil, fmt.Errorf(errStateNotRegistered, name)
} }
if string(rawState) == "null" { if string(rawState) == "null" {
continue return nil, nil //nolint:nilnil
} }
statePtr := reflect.New(stateType).Interface().(State) statePtr := reflect.New(stateType).Interface().(State)
if err := json.Unmarshal(rawState, statePtr); err != nil { if err := json.Unmarshal(rawState, statePtr); err != nil {
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
continue }
}
m.states[name] = statePtr return statePtr, nil
}
// LoadState loads a specific state from the state file
func (m *Manager) LoadState(state State) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return err
}
if rawStates == nil {
return nil
}
name := state.Name()
rawState, exists := rawStates[name]
if !exists {
return nil
}
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
return err
}
m.states[name] = loadedState
if loadedState != nil {
log.Debugf("loaded state: %s", name) log.Debugf("loaded state: %s", name)
} }
return nberrors.FormatErrorOrNil(merr) return nil
} }
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. // cleanupSingleState handles the cleanup of a specific state and returns any error.
// If the cleanup is successful, the state is marked for deletion. // The caller must hold the mutex.
func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error {
// For unregistered states, preserve the raw JSON
if _, registered := m.stateTypes[name]; !registered {
m.states[name] = &RawState{data: rawState}
return nil
}
// Load the state
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
return err
}
if loadedState == nil {
return nil
}
// Check if state supports cleanup
cleanableState, isCleanable := loadedState.(CleanableState)
if !isCleanable {
// If it doesn't support cleanup, keep it as-is
m.states[name] = loadedState
return nil
}
// Perform cleanup
log.Infof("cleaning up state %s", name)
if err := cleanableState.Cleanup(); err != nil {
// On cleanup error, preserve the state
m.states[name] = loadedState
return fmt.Errorf("cleanup state: %w", err)
}
// Successfully cleaned up - mark for deletion
m.states[name] = nil
m.dirty[name] = struct{}{}
return nil
}
// CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState.
// Returns an error if the state doesn't exist, isn't registered, or cleanup fails.
func (m *Manager) CleanupStateByName(name string) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
// Check if state is registered
if _, registered := m.stateTypes[name]; !registered {
return fmt.Errorf(errStateNotRegistered, name)
}
// Load raw states from file
rawStates, err := m.loadStateFile(false)
if err != nil {
return err
}
if rawStates == nil {
return nil
}
// Check if state exists in file
rawState, exists := rawStates[name]
if !exists {
return nil
}
if err := m.cleanupSingleState(name, rawState); err != nil {
return fmt.Errorf("%s: %w", name, err)
}
return nil
}
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
// Unregistered states are preserved in their original state.
func (m *Manager) PerformCleanup() error { func (m *Manager) PerformCleanup() error {
if m == nil { if m == nil {
return nil return nil
@@ -273,26 +460,63 @@ func (m *Manager) PerformCleanup() error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if err := m.loadState(); err != nil { // Load raw states from file
log.Warnf("Failed to load state during cleanup: %v", err) rawStates, err := m.loadStateFile(true)
if err != nil {
return fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil
} }
var merr *multierror.Error var merr *multierror.Error
for name, state := range m.states {
if state == nil {
// If no state was found in the state file, we don't mark the state dirty nor return an error
continue
}
log.Infof("client was not shut down properly, cleaning up %s", name) // Process each state in the file
if err := state.Cleanup(); err != nil { for name, rawState := range rawStates {
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) if err := m.cleanupSingleState(name, rawState); err != nil {
} else { merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err))
// mark for deletion on cleanup success
m.states[name] = nil
m.dirty[name] = struct{}{}
} }
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// GetSavedStateNames returns all state names that are currently saved in the state file.
func (m *Manager) GetSavedStateNames() ([]string, error) {
if m == nil {
return nil, nil
}
rawStates, err := m.loadStateFile(false)
if err != nil {
return nil, fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil, nil
}
var states []string
for name, state := range rawStates {
if len(state) != 0 && string(state) != "null" {
states = append(states, name)
}
}
return states, nil
}
func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte
var err error
func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during marshal: %v", r)
}
}()
bs, err = json.Marshal(v)
}()
return bs, err
}

View File

@@ -4,32 +4,20 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
) )
// GetDefaultStatePath returns the path to the state file based on the operating system // GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. // It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string { func GetDefaultStatePath() string {
var path string
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux": case "darwin", "linux":
path = "/var/lib/netbird/state.json" return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly": case "freebsd", "openbsd", "netbsd", "dragonfly":
path = "/var/db/netbird/state.json" return "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
} }
dir := filepath.Dir(path) return ""
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return ""
}
return path
} }

View File

@@ -59,6 +59,7 @@ func init() {
// Client struct manage the life circle of background service // Client struct manage the life circle of background service
type Client struct { type Client struct {
cfgFile string cfgFile string
stateFile string
recorder *peer.Status recorder *peer.Status
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
@@ -73,9 +74,10 @@ type Client struct {
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
return &Client{ return &Client{
cfgFile: cfgFile, cfgFile: cfgFile,
stateFile: stateFile,
deviceName: deviceName, deviceName: deviceName,
osName: osName, osName: osName,
osVersion: osVersion, osVersion: osVersion,
@@ -91,7 +93,8 @@ func (c *Client) Run(fd int32, interfaceName string) error {
log.Infof("Starting NetBird client") log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName) log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
}) })
if err != nil { if err != nil {
return err return err
@@ -124,7 +127,7 @@ func (c *Client) Run(fd int32, interfaceName string) error {
cfg.WgIface = interfaceName cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager) return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources

View File

@@ -10,9 +10,10 @@ type Preferences struct {
} }
// NewPreferences create new Preferences instance // NewPreferences create new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string, stateFilePath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
StateFilePath: stateFilePath,
} }
return &Preferences{ci} return &Preferences{ci}
} }

View File

@@ -9,7 +9,8 @@ import (
func TestPreferences_DefaultValues(t *testing.T) { func TestPreferences_DefaultValues(t *testing.T) {
cfgFile := filepath.Join(t.TempDir(), "netbird.json") cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile) stateFile := filepath.Join(t.TempDir(), "state.json")
p := NewPreferences(cfgFile, stateFile)
defaultVar, err := p.GetAdminURL() defaultVar, err := p.GetAdminURL()
if err != nil { if err != nil {
t.Fatalf("failed to read default value: %s", err) t.Fatalf("failed to read default value: %s", err)
@@ -42,7 +43,8 @@ func TestPreferences_DefaultValues(t *testing.T) {
func TestPreferences_ReadUncommitedValues(t *testing.T) { func TestPreferences_ReadUncommitedValues(t *testing.T) {
exampleString := "exampleString" exampleString := "exampleString"
cfgFile := filepath.Join(t.TempDir(), "netbird.json") cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile) stateFile := filepath.Join(t.TempDir(), "state.json")
p := NewPreferences(cfgFile, stateFile)
p.SetAdminURL(exampleString) p.SetAdminURL(exampleString)
resp, err := p.GetAdminURL() resp, err := p.GetAdminURL()
@@ -79,7 +81,8 @@ func TestPreferences_Commit(t *testing.T) {
exampleURL := "https://myurl.com:443" exampleURL := "https://myurl.com:443"
examplePresharedKey := "topsecret" examplePresharedKey := "topsecret"
cfgFile := filepath.Join(t.TempDir(), "netbird.json") cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile) stateFile := filepath.Join(t.TempDir(), "state.json")
p := NewPreferences(cfgFile, stateFile)
p.SetAdminURL(exampleURL) p.SetAdminURL(exampleURL)
p.SetManagementURL(exampleURL) p.SetManagementURL(exampleURL)
@@ -90,7 +93,7 @@ func TestPreferences_Commit(t *testing.T) {
t.Fatalf("failed to save changes: %s", err) t.Fatalf("failed to save changes: %s", err)
} }
p = NewPreferences(cfgFile) p = NewPreferences(cfgFile, stateFile)
resp, err := p.GetAdminURL() resp, err := p.GetAdminURL()
if err != nil { if err != nil {
t.Fatalf("failed to read admin url: %s", err) t.Fatalf("failed to read admin url: %s", err)

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.21.12 // protoc v4.23.4
// source: daemon.proto // source: daemon.proto
package proto package proto
@@ -2103,6 +2103,434 @@ func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{30} return file_daemon_proto_rawDescGZIP(), []int{30}
} }
// State represents a daemon state entry
type State struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
}
func (x *State) Reset() {
*x = State{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[31]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *State) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*State) ProtoMessage() {}
func (x *State) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[31]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use State.ProtoReflect.Descriptor instead.
func (*State) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{31}
}
func (x *State) GetName() string {
if x != nil {
return x.Name
}
return ""
}
// ListStatesRequest is empty as it requires no parameters
type ListStatesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *ListStatesRequest) Reset() {
*x = ListStatesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[32]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListStatesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListStatesRequest) ProtoMessage() {}
func (x *ListStatesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[32]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead.
func (*ListStatesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{32}
}
// ListStatesResponse contains a list of states
type ListStatesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
States []*State `protobuf:"bytes,1,rep,name=states,proto3" json:"states,omitempty"`
}
func (x *ListStatesResponse) Reset() {
*x = ListStatesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[33]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListStatesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListStatesResponse) ProtoMessage() {}
func (x *ListStatesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[33]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead.
func (*ListStatesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{33}
}
func (x *ListStatesResponse) GetStates() []*State {
if x != nil {
return x.States
}
return nil
}
// CleanStateRequest for cleaning states
type CleanStateRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"`
All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *CleanStateRequest) Reset() {
*x = CleanStateRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[34]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CleanStateRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CleanStateRequest) ProtoMessage() {}
func (x *CleanStateRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[34]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead.
func (*CleanStateRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{34}
}
func (x *CleanStateRequest) GetStateName() string {
if x != nil {
return x.StateName
}
return ""
}
func (x *CleanStateRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
// CleanStateResponse contains the result of the clean operation
type CleanStateResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
CleanedStates int32 `protobuf:"varint,1,opt,name=cleaned_states,json=cleanedStates,proto3" json:"cleaned_states,omitempty"`
}
func (x *CleanStateResponse) Reset() {
*x = CleanStateResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[35]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CleanStateResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CleanStateResponse) ProtoMessage() {}
func (x *CleanStateResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[35]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead.
func (*CleanStateResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{35}
}
func (x *CleanStateResponse) GetCleanedStates() int32 {
if x != nil {
return x.CleanedStates
}
return 0
}
// DeleteStateRequest for deleting states
type DeleteStateRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"`
All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *DeleteStateRequest) Reset() {
*x = DeleteStateRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[36]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DeleteStateRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DeleteStateRequest) ProtoMessage() {}
func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[36]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead.
func (*DeleteStateRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{36}
}
func (x *DeleteStateRequest) GetStateName() string {
if x != nil {
return x.StateName
}
return ""
}
func (x *DeleteStateRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
// DeleteStateResponse contains the result of the delete operation
type DeleteStateResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
DeletedStates int32 `protobuf:"varint,1,opt,name=deleted_states,json=deletedStates,proto3" json:"deleted_states,omitempty"`
}
func (x *DeleteStateResponse) Reset() {
*x = DeleteStateResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[37]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DeleteStateResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DeleteStateResponse) ProtoMessage() {}
func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[37]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead.
func (*DeleteStateResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{37}
}
func (x *DeleteStateResponse) GetDeletedStates() int32 {
if x != nil {
return x.DeletedStates
}
return 0
}
type SetNetworkMapPersistenceRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
}
func (x *SetNetworkMapPersistenceRequest) Reset() {
*x = SetNetworkMapPersistenceRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[38]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetNetworkMapPersistenceRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetNetworkMapPersistenceRequest) ProtoMessage() {}
func (x *SetNetworkMapPersistenceRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[38]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetNetworkMapPersistenceRequest.ProtoReflect.Descriptor instead.
func (*SetNetworkMapPersistenceRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{38}
}
func (x *SetNetworkMapPersistenceRequest) GetEnabled() bool {
if x != nil {
return x.Enabled
}
return false
}
type SetNetworkMapPersistenceResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SetNetworkMapPersistenceResponse) Reset() {
*x = SetNetworkMapPersistenceResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[39]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetNetworkMapPersistenceResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetNetworkMapPersistenceResponse) ProtoMessage() {}
func (x *SetNetworkMapPersistenceResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[39]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetNetworkMapPersistenceResponse.ProtoReflect.Descriptor instead.
func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{39}
}
var File_daemon_proto protoreflect.FileDescriptor var File_daemon_proto protoreflect.FileDescriptor
var file_daemon_proto_rawDesc = []byte{ var file_daemon_proto_rawDesc = []byte{
@@ -2399,66 +2827,116 @@ var file_daemon_proto_rawDesc = []byte{
0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74,
0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d,
0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a,
0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73,
0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74,
0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71,
0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e,
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08,
0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63,
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74,
0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74,
0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02,
0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c,
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65,
0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e,
0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61,
0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f,
0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65,
0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c,
0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05,
0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52,
0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04,
0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10,
0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x81, 0x09, 0x0a,
0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36,
0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53,
0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69,
0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a,
0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44,
0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66,
0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69,
0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65,
0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74,
0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62,
0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c,
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a,
0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53,
0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74,
0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45,
0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53,
0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65,
0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61,
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71,
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69,
0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x6f, 0x33,
0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47,
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c,
0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67,
0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42,
0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
} }
var ( var (
@@ -2474,50 +2952,59 @@ func file_daemon_proto_rawDescGZIP() []byte {
} }
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 32) var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41)
var file_daemon_proto_goTypes = []interface{}{ var file_daemon_proto_goTypes = []interface{}{
(LogLevel)(0), // 0: daemon.LogLevel (LogLevel)(0), // 0: daemon.LogLevel
(*LoginRequest)(nil), // 1: daemon.LoginRequest (*LoginRequest)(nil), // 1: daemon.LoginRequest
(*LoginResponse)(nil), // 2: daemon.LoginResponse (*LoginResponse)(nil), // 2: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 5: daemon.UpRequest (*UpRequest)(nil), // 5: daemon.UpRequest
(*UpResponse)(nil), // 6: daemon.UpResponse (*UpResponse)(nil), // 6: daemon.UpResponse
(*StatusRequest)(nil), // 7: daemon.StatusRequest (*StatusRequest)(nil), // 7: daemon.StatusRequest
(*StatusResponse)(nil), // 8: daemon.StatusResponse (*StatusResponse)(nil), // 8: daemon.StatusResponse
(*DownRequest)(nil), // 9: daemon.DownRequest (*DownRequest)(nil), // 9: daemon.DownRequest
(*DownResponse)(nil), // 10: daemon.DownResponse (*DownResponse)(nil), // 10: daemon.DownResponse
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
(*PeerState)(nil), // 13: daemon.PeerState (*PeerState)(nil), // 13: daemon.PeerState
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState (*LocalPeerState)(nil), // 14: daemon.LocalPeerState
(*SignalState)(nil), // 15: daemon.SignalState (*SignalState)(nil), // 15: daemon.SignalState
(*ManagementState)(nil), // 16: daemon.ManagementState (*ManagementState)(nil), // 16: daemon.ManagementState
(*RelayState)(nil), // 17: daemon.RelayState (*RelayState)(nil), // 17: daemon.RelayState
(*NSGroupState)(nil), // 18: daemon.NSGroupState (*NSGroupState)(nil), // 18: daemon.NSGroupState
(*FullStatus)(nil), // 19: daemon.FullStatus (*FullStatus)(nil), // 19: daemon.FullStatus
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
(*IPList)(nil), // 24: daemon.IPList (*IPList)(nil), // 24: daemon.IPList
(*Route)(nil), // 25: daemon.Route (*Route)(nil), // 25: daemon.Route
(*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse
(*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse (*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest (*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse (*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse
nil, // 32: daemon.Route.ResolvedIPsEntry (*State)(nil), // 32: daemon.State
(*durationpb.Duration)(nil), // 33: google.protobuf.Duration (*ListStatesRequest)(nil), // 33: daemon.ListStatesRequest
(*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp (*ListStatesResponse)(nil), // 34: daemon.ListStatesResponse
(*CleanStateRequest)(nil), // 35: daemon.CleanStateRequest
(*CleanStateResponse)(nil), // 36: daemon.CleanStateResponse
(*DeleteStateRequest)(nil), // 37: daemon.DeleteStateRequest
(*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse
(*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest
(*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse
nil, // 41: daemon.Route.ResolvedIPsEntry
(*durationpb.Duration)(nil), // 42: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
33, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
34, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
34, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
33, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -2525,39 +3012,48 @@ var file_daemon_proto_depIdxs = []int32{
17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState
18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
32, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry 41, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry
0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
24, // 15: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State
1, // 16: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 24, // 16: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList
3, // 17: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
5, // 18: daemon.DaemonService.Up:input_type -> daemon.UpRequest 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
7, // 19: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest
9, // 20: daemon.DaemonService.Down:input_type -> daemon.DownRequest 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
11, // 21: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest
20, // 22: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
22, // 23: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest 20, // 23: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
22, // 24: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest 22, // 24: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
26, // 25: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 22, // 25: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
28, // 26: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
30, // 27: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
2, // 28: daemon.DaemonService.Login:output_type -> daemon.LoginResponse 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
4, // 29: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse 33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
6, // 30: daemon.DaemonService.Up:output_type -> daemon.UpResponse 35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
8, // 31: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
10, // 32: daemon.DaemonService.Down:output_type -> daemon.DownResponse 39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
12, // 33: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse 2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
21, // 34: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse 4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
23, // 35: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse 6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse
23, // 36: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
27, // 37: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse
29, // 38: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
31, // 39: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse 21, // 39: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
28, // [28:40] is the sub-list for method output_type 23, // 40: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
16, // [16:28] is the sub-list for method input_type 23, // 41: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
16, // [16:16] is the sub-list for extension type_name 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
16, // [16:16] is the sub-list for extension extendee 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
0, // [0:16] is the sub-list for field type_name 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
33, // [33:49] is the sub-list for method output_type
17, // [17:33] is the sub-list for method input_type
17, // [17:17] is the sub-list for extension type_name
17, // [17:17] is the sub-list for extension extendee
0, // [0:17] is the sub-list for field type_name
} }
func init() { file_daemon_proto_init() } func init() { file_daemon_proto_init() }
@@ -2938,6 +3434,114 @@ func file_daemon_proto_init() {
return nil return nil
} }
} }
file_daemon_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*State); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListStatesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListStatesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CleanStateRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CleanStateResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DeleteStateRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DeleteStateResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetNetworkMapPersistenceRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetNetworkMapPersistenceResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
type x struct{} type x struct{}
@@ -2946,7 +3550,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_daemon_proto_rawDesc, RawDescriptor: file_daemon_proto_rawDesc,
NumEnums: 1, NumEnums: 1,
NumMessages: 32, NumMessages: 41,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -45,7 +45,20 @@ service DaemonService {
// SetLogLevel sets the log level of the daemon // SetLogLevel sets the log level of the daemon
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {} rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
};
// List all states
rpc ListStates(ListStatesRequest) returns (ListStatesResponse) {}
// Clean specific state or all states
rpc CleanState(CleanStateRequest) returns (CleanStateResponse) {}
// Delete specific state or all states
rpc DeleteState(DeleteStateRequest) returns (DeleteStateResponse) {}
// SetNetworkMapPersistence enables or disables network map persistence
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
}
message LoginRequest { message LoginRequest {
// setupKey wiretrustee setup key. // setupKey wiretrustee setup key.
@@ -293,4 +306,46 @@ message SetLogLevelRequest {
} }
message SetLogLevelResponse { message SetLogLevelResponse {
} }
// State represents a daemon state entry
message State {
string name = 1;
}
// ListStatesRequest is empty as it requires no parameters
message ListStatesRequest {}
// ListStatesResponse contains a list of states
message ListStatesResponse {
repeated State states = 1;
}
// CleanStateRequest for cleaning states
message CleanStateRequest {
string state_name = 1;
bool all = 2;
}
// CleanStateResponse contains the result of the clean operation
message CleanStateResponse {
int32 cleaned_states = 1;
}
// DeleteStateRequest for deleting states
message DeleteStateRequest {
string state_name = 1;
bool all = 2;
}
// DeleteStateResponse contains the result of the delete operation
message DeleteStateResponse {
int32 deleted_states = 1;
}
message SetNetworkMapPersistenceRequest {
bool enabled = 1;
}
message SetNetworkMapPersistenceResponse {}

View File

@@ -43,6 +43,14 @@ type DaemonServiceClient interface {
GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error)
// SetLogLevel sets the log level of the daemon // SetLogLevel sets the log level of the daemon
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
// List all states
ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error)
// Clean specific state or all states
CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error)
// Delete specific state or all states
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
// SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
} }
type daemonServiceClient struct { type daemonServiceClient struct {
@@ -161,6 +169,42 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe
return out, nil return out, nil
} }
func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) {
out := new(ListStatesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) {
out := new(CleanStateResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) {
out := new(DeleteStateResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) {
out := new(SetNetworkMapPersistenceResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility // for forward compatibility
@@ -190,6 +234,14 @@ type DaemonServiceServer interface {
GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error)
// SetLogLevel sets the log level of the daemon // SetLogLevel sets the log level of the daemon
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
// List all states
ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error)
// Clean specific state or all states
CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error)
// Delete specific state or all states
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
// SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
@@ -233,6 +285,18 @@ func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLeve
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) { func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented") return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
} }
func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListStates not implemented")
}
func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented")
}
func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented")
}
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -462,6 +526,78 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListStatesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ListStates(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListStates",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CleanStateRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).CleanState(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/CleanState",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DeleteStateRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DeleteState(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeleteState",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetNetworkMapPersistenceRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -517,6 +653,22 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "SetLogLevel", MethodName: "SetLogLevel",
Handler: _DaemonService_SetLogLevel_Handler, Handler: _DaemonService_SetLogLevel_Handler,
}, },
{
MethodName: "ListStates",
Handler: _DaemonService_ListStates_Handler,
},
{
MethodName: "CleanState",
Handler: _DaemonService_CleanState_Handler,
},
{
MethodName: "DeleteState",
Handler: _DaemonService_DeleteState_Handler,
},
{
MethodName: "SetNetworkMapPersistence",
Handler: _DaemonService_SetNetworkMapPersistence_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "daemon.proto", Metadata: "daemon.proto",

View File

@@ -5,32 +5,44 @@ package server
import ( import (
"archive/zip" "archive/zip"
"bufio" "bufio"
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"path/filepath"
"sort" "sort"
"strings" "strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/management/proto"
) )
const readmeContent = `Netbird debug bundle const readmeContent = `Netbird debug bundle
This debug bundle contains the following files: This debug bundle contains the following files:
status.txt: Anonymized status information of the NetBird client. status.txt: Anonymized status information of the NetBird client.
client.log: Most recent, anonymized log file of the NetBird client. client.log: Most recent, anonymized client log file of the NetBird client.
netbird.err: Most recent, anonymized stderr log file of the NetBird client.
netbird.out: Most recent, anonymized stdout log file of the NetBird client.
routes.txt: Anonymized system routes, if --system-info flag was provided. routes.txt: Anonymized system routes, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
config.txt: Anonymized configuration information of the NetBird client. config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states.
Anonymization Process Anonymization Process
@@ -50,8 +62,32 @@ Domains
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle. All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
Reoccuring domain names are replaced with the same anonymized domain. Reoccuring domain names are replaced with the same anonymized domain.
Network Map
The network_map.json file contains the following anonymized information:
- Peer configurations (addresses, FQDNs, DNS settings)
- Remote and offline peer information (allowed IPs, FQDNs)
- Routes (network ranges, associated domains)
- DNS configuration (nameservers, domains, custom zones)
- Firewall rules (peer IPs, source/destination ranges)
SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above.
State File
The state.json file contains anonymized internal state information of the NetBird client, including:
- DNS settings and configuration
- Firewall rules
- Exclusion routes
- Route selection
- Other internal states that may be present
The state file follows the same anonymization rules as other files:
- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure
- Domain names are consistently anonymized
- Technical identifiers and non-sensitive data remain unchanged
Routes Routes
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
Network Interfaces Network Interfaces
The interfaces.txt file contains information about network interfaces, including: The interfaces.txt file contains information about network interfaces, including:
- Interface name - Interface name
@@ -72,6 +108,12 @@ The config.txt file contains anonymized configuration information of the NetBird
Other non-sensitive configuration options are included without anonymization. Other non-sensitive configuration options are included without anonymization.
` `
const (
clientLogFile = "client.log"
errorLogFile = "netbird.err"
stdoutLogFile = "netbird.out"
)
// DebugBundle creates a debug bundle and returns the location. // DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock() s.mutex.Lock()
@@ -119,19 +161,27 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
seedFromStatus(anonymizer, &status) seedFromStatus(anonymizer, &status)
if err := s.addConfig(req, anonymizer, archive); err != nil { if err := s.addConfig(req, anonymizer, archive); err != nil {
return fmt.Errorf("add config: %w", err) log.Errorf("Failed to add config to debug bundle: %v", err)
} }
if req.GetSystemInfo() { if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil { if err := s.addRoutes(req, anonymizer, archive); err != nil {
return fmt.Errorf("add routes: %w", err) log.Errorf("Failed to add routes to debug bundle: %v", err)
} }
if err := s.addInterfaces(req, anonymizer, archive); err != nil { if err := s.addInterfaces(req, anonymizer, archive); err != nil {
return fmt.Errorf("add interfaces: %w", err) log.Errorf("Failed to add interfaces to debug bundle: %v", err)
} }
} }
if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
return fmt.Errorf("add network map: %w", err)
}
if err := s.addStateFile(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add state file to debug bundle: %v", err)
}
if err := s.addLogfile(req, anonymizer, archive); err != nil { if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err) return fmt.Errorf("add log file: %w", err)
} }
@@ -220,15 +270,16 @@ func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
} }
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
if routes, err := systemops.GetRoutesFromTable(); err != nil { routes, err := systemops.GetRoutesFromTable()
log.Errorf("Failed to get routes: %v", err) if err != nil {
} else { return fmt.Errorf("get routes: %w", err)
// TODO: get routes including nexthop }
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
routesReader := strings.NewReader(routesContent) // TODO: get routes including nexthop
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil { routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
return fmt.Errorf("add routes file to zip: %w", err) routesReader := strings.NewReader(routesContent)
} if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
} }
return nil return nil
} }
@@ -248,14 +299,106 @@ func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonym
return nil return nil
} }
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
logFile, err := os.Open(s.logFile) networkMap, err := s.getLatestNetworkMap()
if err != nil { if err != nil {
return fmt.Errorf("open log file: %w", err) // Skip if network map is not available, but log it
log.Debugf("skipping empty network map in debug bundle: %v", err)
return nil
}
if req.GetAnonymize() {
if err := anonymizeNetworkMap(networkMap, anonymizer); err != nil {
return fmt.Errorf("anonymize network map: %w", err)
}
}
options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
Indent: " ",
AllowPartial: true,
}
jsonBytes, err := options.Marshal(networkMap)
if err != nil {
return fmt.Errorf("generate json: %w", err)
}
if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "network_map.json"); err != nil {
return fmt.Errorf("add network map to zip: %w", err)
}
return nil
}
func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
path := statemanager.GetDefaultStatePath()
if path == "" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return fmt.Errorf("read state file: %w", err)
}
if req.GetAnonymize() {
var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil {
return fmt.Errorf("unmarshal states: %w", err)
}
if err := anonymizeStateFile(&rawStates, anonymizer); err != nil {
return fmt.Errorf("anonymize state file: %w", err)
}
bs, err := json.MarshalIndent(rawStates, "", " ")
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
data = bs
}
if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil {
return fmt.Errorf("add state file to zip: %w", err)
}
return nil
}
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
logDir := filepath.Dir(s.logFile)
if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil {
return fmt.Errorf("add client log file to zip: %w", err)
}
errLogPath := filepath.Join(logDir, errorLogFile)
if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil {
log.Warnf("Failed to add %s to zip: %v", errorLogFile, err)
}
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil {
log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err)
}
return nil
}
// addSingleLogfile adds a single log file to the archive
func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
logFile, err := os.Open(logPath)
if err != nil {
return fmt.Errorf("open log file %s: %w", targetName, err)
} }
defer func() { defer func() {
if err := logFile.Close(); err != nil { if err := logFile.Close(); err != nil {
log.Errorf("Failed to close original log file: %v", err) log.Errorf("Failed to close log file %s: %v", targetName, err)
} }
}() }()
@@ -264,45 +407,55 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize
var writer *io.PipeWriter var writer *io.PipeWriter
logReader, writer = io.Pipe() logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, anonymizer) go anonymizeLog(logFile, writer, anonymizer)
} else { } else {
logReader = logFile logReader = logFile
} }
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
return fmt.Errorf("add log file to zip: %w", err) if err := addFileToZip(archive, logReader, targetName); err != nil {
return fmt.Errorf("add %s to zip: %w", targetName, err)
} }
return nil return nil
} }
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { // getLatestNetworkMap returns the latest network map from the engine if network map persistence is enabled
defer func() { func (s *Server) getLatestNetworkMap() (*mgmProto.NetworkMap, error) {
// always nil if s.connectClient == nil {
_ = writer.Close() return nil, errors.New("connect client is not initialized")
}() }
scanner := bufio.NewScanner(reader) engine := s.connectClient.Engine()
for scanner.Scan() { if engine == nil {
line := anonymizer.AnonymizeString(scanner.Text()) return nil, errors.New("engine is not initialized")
if _, err := writer.Write([]byte(line + "\n")); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return
}
} }
if err := scanner.Err(); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) networkMap, err := engine.GetLatestNetworkMap()
return if err != nil {
return nil, fmt.Errorf("get latest network map: %w", err)
} }
if networkMap == nil {
return nil, errors.New("network map is not available")
}
return networkMap, nil
} }
// GetLogLevel gets the current logging level for the server. // GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) { func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
level := ParseLogLevel(log.GetLevel().String()) level := ParseLogLevel(log.GetLevel().String())
return &proto.GetLogLevelResponse{Level: level}, nil return &proto.GetLogLevelResponse{Level: level}, nil
} }
// SetLogLevel sets the logging level for the server. // SetLogLevel sets the logging level for the server.
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) { func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
level, err := log.ParseLevel(req.Level.String()) level, err := log.ParseLevel(req.Level.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid log level: %w", err) return nil, fmt.Errorf("invalid log level: %w", err)
@@ -313,6 +466,20 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
return &proto.SetLogLevelResponse{}, nil return &proto.SetLogLevelResponse{}, nil
} }
// SetNetworkMapPersistence sets the network map persistence for the server.
func (s *Server) SetNetworkMapPersistence(_ context.Context, req *proto.SetNetworkMapPersistenceRequest) (*proto.SetNetworkMapPersistenceResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
enabled := req.GetEnabled()
s.persistNetworkMap = enabled
if s.connectClient != nil {
s.connectClient.SetNetworkMapPersistence(enabled)
}
return &proto.SetNetworkMapPersistenceResponse{}, nil
}
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error { func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{ header := &zip.FileHeader{
Name: filename, Name: filename,
@@ -458,6 +625,26 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an
return builder.String() return builder.String()
} }
func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
defer func() {
// always nil
_ = writer.Close()
}()
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return
}
}
if err := scanner.Err(); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
return
}
}
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string { func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
anonymizedIPs := make([]string, len(ips)) anonymizedIPs := make([]string, len(ips))
for i, ip := range ips { for i, ip := range ips {
@@ -484,3 +671,248 @@ func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []s
} }
return anonymizedIPs return anonymizedIPs
} }
func anonymizeNetworkMap(networkMap *mgmProto.NetworkMap, anonymizer *anonymize.Anonymizer) error {
if networkMap.PeerConfig != nil {
anonymizePeerConfig(networkMap.PeerConfig, anonymizer)
}
for _, peer := range networkMap.RemotePeers {
anonymizeRemotePeer(peer, anonymizer)
}
for _, peer := range networkMap.OfflinePeers {
anonymizeRemotePeer(peer, anonymizer)
}
for _, r := range networkMap.Routes {
anonymizeRoute(r, anonymizer)
}
if networkMap.DNSConfig != nil {
anonymizeDNSConfig(networkMap.DNSConfig, anonymizer)
}
for _, rule := range networkMap.FirewallRules {
anonymizeFirewallRule(rule, anonymizer)
}
for _, rule := range networkMap.RoutesFirewallRules {
anonymizeRouteFirewallRule(rule, anonymizer)
}
return nil
}
func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anonymizer) {
if config == nil {
return
}
if addr, err := netip.ParseAddr(config.Address); err == nil {
config.Address = anonymizer.AnonymizeIP(addr).String()
}
if config.SshConfig != nil && len(config.SshConfig.SshPubKey) > 0 {
config.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
}
config.Dns = anonymizer.AnonymizeString(config.Dns)
config.Fqdn = anonymizer.AnonymizeDomain(config.Fqdn)
}
func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.Anonymizer) {
if peer == nil {
return
}
for i, ip := range peer.AllowedIps {
// Try to parse as prefix first (CIDR)
if prefix, err := netip.ParsePrefix(ip); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
} else if addr, err := netip.ParseAddr(ip); err == nil {
peer.AllowedIps[i] = anonymizer.AnonymizeIP(addr).String()
}
}
peer.Fqdn = anonymizer.AnonymizeDomain(peer.Fqdn)
if peer.SshConfig != nil && len(peer.SshConfig.SshPubKey) > 0 {
peer.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
}
}
func anonymizeRoute(route *mgmProto.Route, anonymizer *anonymize.Anonymizer) {
if route == nil {
return
}
if prefix, err := netip.ParsePrefix(route.Network); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
route.Network = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
for i, domain := range route.Domains {
route.Domains[i] = anonymizer.AnonymizeDomain(domain)
}
route.NetID = anonymizer.AnonymizeString(route.NetID)
}
func anonymizeDNSConfig(config *mgmProto.DNSConfig, anonymizer *anonymize.Anonymizer) {
if config == nil {
return
}
anonymizeNameServerGroups(config.NameServerGroups, anonymizer)
anonymizeCustomZones(config.CustomZones, anonymizer)
}
func anonymizeNameServerGroups(groups []*mgmProto.NameServerGroup, anonymizer *anonymize.Anonymizer) {
for _, group := range groups {
anonymizeServers(group.NameServers, anonymizer)
anonymizeDomains(group.Domains, anonymizer)
}
}
func anonymizeServers(servers []*mgmProto.NameServer, anonymizer *anonymize.Anonymizer) {
for _, server := range servers {
if addr, err := netip.ParseAddr(server.IP); err == nil {
server.IP = anonymizer.AnonymizeIP(addr).String()
}
}
}
func anonymizeDomains(domains []string, anonymizer *anonymize.Anonymizer) {
for i, domain := range domains {
domains[i] = anonymizer.AnonymizeDomain(domain)
}
}
func anonymizeCustomZones(zones []*mgmProto.CustomZone, anonymizer *anonymize.Anonymizer) {
for _, zone := range zones {
zone.Domain = anonymizer.AnonymizeDomain(zone.Domain)
anonymizeRecords(zone.Records, anonymizer)
}
}
func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
for _, record := range records {
record.Name = anonymizer.AnonymizeDomain(record.Name)
anonymizeRData(record, anonymizer)
}
}
func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
switch record.Type {
case 1, 28: // A or AAAA record
if addr, err := netip.ParseAddr(record.RData); err == nil {
record.RData = anonymizer.AnonymizeIP(addr).String()
}
default:
record.RData = anonymizer.AnonymizeString(record.RData)
}
}
func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.Anonymizer) {
if rule == nil {
return
}
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
}
}
func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *anonymize.Anonymizer) {
if rule == nil {
return
}
for i, sourceRange := range rule.SourceRanges {
if prefix, err := netip.ParsePrefix(sourceRange); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
rule.SourceRanges[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
}
if prefix, err := netip.ParsePrefix(rule.Destination); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
}
func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error {
for name, rawState := range *rawStates {
if string(rawState) == "null" {
continue
}
var state map[string]any
if err := json.Unmarshal(rawState, &state); err != nil {
return fmt.Errorf("unmarshal state %s: %w", name, err)
}
state = anonymizeValue(state, anonymizer).(map[string]any)
bs, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal state %s: %w", name, err)
}
(*rawStates)[name] = bs
}
return nil
}
func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any {
switch v := value.(type) {
case string:
return anonymizeString(v, anonymizer)
case map[string]any:
return anonymizeMap(v, anonymizer)
case []any:
return anonymizeSlice(v, anonymizer)
}
return value
}
func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string {
if prefix, err := netip.ParsePrefix(v); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
if ip, err := netip.ParseAddr(v); err == nil {
return anonymizer.AnonymizeIP(ip).String()
}
return anonymizer.AnonymizeString(v)
}
func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any {
result := make(map[string]any, len(v))
for key, val := range v {
newKey := anonymizeMapKey(key, anonymizer)
result[newKey] = anonymizeValue(val, anonymizer)
}
return result
}
func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string {
if prefix, err := netip.ParsePrefix(key); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
if ip, err := netip.ParseAddr(key); err == nil {
return anonymizer.AnonymizeIP(ip).String()
}
return key
}
func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any {
for i, val := range v {
v[i] = anonymizeValue(val, anonymizer)
}
return v
}

430
client/server/debug_test.go Normal file
View File

@@ -0,0 +1,430 @@
package server
import (
"encoding/json"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func TestAnonymizeStateFile(t *testing.T) {
testState := map[string]json.RawMessage{
"null_state": json.RawMessage("null"),
"test_state": mustMarshal(map[string]any{
// Test simple fields
"public_ip": "203.0.113.1",
"private_ip": "192.168.1.1",
"protected_ip": "100.64.0.1",
"well_known_ip": "8.8.8.8",
"ipv6_addr": "2001:db8::1",
"private_ipv6": "fd00::1",
"domain": "test.example.com",
"uri": "stun:stun.example.com:3478",
"uri_with_ip": "turn:203.0.113.1:3478",
"netbird_domain": "device.netbird.cloud",
// Test CIDR ranges
"public_cidr": "203.0.113.0/24",
"private_cidr": "192.168.0.0/16",
"protected_cidr": "100.64.0.0/10",
"ipv6_cidr": "2001:db8::/32",
"private_ipv6_cidr": "fd00::/8",
// Test nested structures
"nested": map[string]any{
"ip": "203.0.113.2",
"domain": "nested.example.com",
"more_nest": map[string]any{
"ip": "203.0.113.3",
"domain": "deep.example.com",
},
},
// Test arrays
"string_array": []any{
"203.0.113.4",
"test1.example.com",
"test2.example.com",
},
"object_array": []any{
map[string]any{
"ip": "203.0.113.5",
"domain": "array1.example.com",
},
map[string]any{
"ip": "203.0.113.6",
"domain": "array2.example.com",
},
},
// Test multiple occurrences of same value
"duplicate_ip": "203.0.113.1", // Same as public_ip
"duplicate_domain": "test.example.com", // Same as domain
// Test URIs with various schemes
"stun_uri": "stun:stun.example.com:3478",
"turns_uri": "turns:turns.example.com:5349",
"http_uri": "http://web.example.com:80",
"https_uri": "https://secure.example.com:443",
// Test strings that might look like IPs but aren't
"not_ip": "300.300.300.300",
"partial_ip": "192.168",
"ip_like_string": "1234.5678",
// Test mixed content strings
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
// Test empty and special values
"empty_string": "",
"null_value": nil,
"numeric_value": 42,
"boolean_value": true,
}),
"route_state": mustMarshal(map[string]any{
"routes": []any{
map[string]any{
"network": "203.0.113.0/24",
"gateway": "203.0.113.1",
"domains": []any{
"route1.example.com",
"route2.example.com",
},
},
map[string]any{
"network": "2001:db8::/32",
"gateway": "2001:db8::1",
"domains": []any{
"route3.example.com",
"route4.example.com",
},
},
},
// Test map with IP/CIDR keys
"refCountMap": map[string]any{
"203.0.113.1/32": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"2001:db8::1/128": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "fe80::1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
},
},
},
}),
}
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Pre-seed the domains we need to verify in the test assertions
anonymizer.AnonymizeDomain("test.example.com")
anonymizer.AnonymizeDomain("nested.example.com")
anonymizer.AnonymizeDomain("deep.example.com")
anonymizer.AnonymizeDomain("array1.example.com")
err := anonymizeStateFile(&testState, anonymizer)
require.NoError(t, err)
// Helper function to unmarshal and get nested values
var state map[string]any
err = json.Unmarshal(testState["test_state"], &state)
require.NoError(t, err)
// Test null state remains unchanged
require.Equal(t, "null", string(testState["null_state"]))
// Basic assertions
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
assert.NotEqual(t, "test.example.com", state["domain"])
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
// CIDR ranges
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
// Nested structures
nested := state["nested"].(map[string]any)
assert.NotEqual(t, "203.0.113.2", nested["ip"])
assert.NotEqual(t, "nested.example.com", nested["domain"])
moreNest := nested["more_nest"].(map[string]any)
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
// Arrays
strArray := state["string_array"].([]any)
assert.NotEqual(t, "203.0.113.4", strArray[0])
assert.NotEqual(t, "test1.example.com", strArray[1])
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
objArray := state["object_array"].([]any)
firstObj := objArray[0].(map[string]any)
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
// Duplicate values should be anonymized consistently
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
assert.Equal(t, state["domain"], state["duplicate_domain"])
// URIs
assert.NotContains(t, state["stun_uri"], "stun.example.com")
assert.NotContains(t, state["turns_uri"], "turns.example.com")
assert.NotContains(t, state["http_uri"], "web.example.com")
assert.NotContains(t, state["https_uri"], "secure.example.com")
// Non-IP strings should remain unchanged
assert.Equal(t, "300.300.300.300", state["not_ip"])
assert.Equal(t, "192.168", state["partial_ip"])
assert.Equal(t, "1234.5678", state["ip_like_string"])
// Mixed content should have IPs and domains replaced
mixedContent := state["mixed_content"].(string)
assert.NotContains(t, mixedContent, "203.0.113.1")
assert.NotContains(t, mixedContent, "test.example.com")
assert.Contains(t, mixedContent, "Server at ")
assert.Contains(t, mixedContent, " on port 80")
// Special values should remain unchanged
assert.Equal(t, "", state["empty_string"])
assert.Nil(t, state["null_value"])
assert.Equal(t, float64(42), state["numeric_value"])
assert.Equal(t, true, state["boolean_value"])
// Check route state
var routeState map[string]any
err = json.Unmarshal(testState["route_state"], &routeState)
require.NoError(t, err)
routes := routeState["routes"].([]any)
route1 := routes[0].(map[string]any)
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
assert.Contains(t, route1["network"], "/24")
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
domains := route1["domains"].([]any)
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
// Check map keys are anonymized
refCountMap := routeState["refCountMap"].(map[string]any)
hasPublicIPKey := false
hasIPv6Key := false
hasPrivateIPKey := false
for key := range refCountMap {
if strings.Contains(key, "203.0.113.1") {
hasPublicIPKey = true
}
if strings.Contains(key, "2001:db8::1") {
hasIPv6Key = true
}
if key == "10.0.0.1/32" {
hasPrivateIPKey = true
}
}
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
}
func mustMarshal(v any) json.RawMessage {
data, err := json.Marshal(v)
if err != nil {
panic(err)
}
return data
}
func TestAnonymizeNetworkMap(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
Address: "203.0.113.5",
Dns: "1.2.3.4",
Fqdn: "peer1.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
},
},
RemotePeers: []*mgmProto.RemotePeerConfig{
{
AllowedIps: []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
},
Fqdn: "peer2.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
},
},
},
Routes: []*mgmProto.Route{
{
Network: "197.51.100.0/24",
Domains: []string{"prod.example.com", "staging.example.com"},
NetID: "net-123abc",
},
},
DNSConfig: &mgmProto.DNSConfig{
NameServerGroups: []*mgmProto.NameServerGroup{
{
NameServers: []*mgmProto.NameServer{
{IP: "8.8.8.8"},
{IP: "1.1.1.1"},
{IP: "203.0.113.53"},
},
Domains: []string{"example.com", "internal.example.com"},
},
},
CustomZones: []*mgmProto.CustomZone{
{
Domain: "custom.example.com",
Records: []*mgmProto.SimpleRecord{
{
Name: "www.custom.example.com",
Type: 1,
RData: "203.0.113.10",
},
{
Name: "internal.custom.example.com",
Type: 1,
RData: "192.168.1.10",
},
},
},
},
},
}
// Create anonymizer with test addresses
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Anonymize the network map
err := anonymizeNetworkMap(networkMap, anonymizer)
require.NoError(t, err)
// Test PeerConfig anonymization
peerCfg := networkMap.PeerConfig
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
// Verify DNS and FQDN are properly anonymized
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
// Verify SSH key is replaced
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
// Test RemotePeers anonymization
remotePeer := networkMap.RemotePeers[0]
// Verify FQDN is anonymized
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
// Check that public IPs are anonymized but private IPs are preserved
for _, allowedIP := range remotePeer.AllowedIps {
ip, _, err := net.ParseCIDR(allowedIP)
require.NoError(t, err)
if ip.IsPrivate() || isInCGNATRange(ip) {
require.Contains(t, []string{
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
}, allowedIP)
} else {
require.NotContains(t, []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
}, allowedIP)
}
}
// Test Routes anonymization
route := networkMap.Routes[0]
require.NotEqual(t, "197.51.100.0/24", route.Network)
for _, domain := range route.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test DNS config anonymization
dnsConfig := networkMap.DNSConfig
nameServerGroup := dnsConfig.NameServerGroups[0]
// Verify well-known DNS servers are preserved
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
// Verify public DNS server is anonymized
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
// Verify domains are anonymized
for _, domain := range nameServerGroup.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test CustomZones anonymization
customZone := dnsConfig.CustomZones[0]
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
require.NotContains(t, customZone.Domain, "example.com")
// Verify records are properly anonymized
for _, record := range customZone.Records {
require.True(t, strings.HasSuffix(record.Name, ".domain"))
require.NotContains(t, record.Name, "example.com")
ip := net.ParseIP(record.RData)
if ip != nil {
if !ip.IsPrivate() {
require.NotEqual(t, "203.0.113.10", record.RData)
} else {
require.Equal(t, "192.168.1.10", record.RData)
}
}
}
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
return cgnat.Contains(ip)
}

View File

@@ -68,6 +68,8 @@ type Server struct {
relayProbe *internal.Probe relayProbe *internal.Probe
wgProbe *internal.Probe wgProbe *internal.Probe
lastProbe time.Time lastProbe time.Time
persistNetworkMap bool
} }
type oauthAuthFlow struct { type oauthAuthFlow struct {
@@ -196,6 +198,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error { runOperation := func() error {
log.Tracef("running client connection") log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetNetworkMapPersistence(s.persistNetworkMap)
probes := internal.ProbeHolder{ probes := internal.ProbeHolder{
MgmProbe: s.mgmProbe, MgmProbe: s.mgmProbe,

View File

@@ -5,12 +5,112 @@ import (
"fmt" "fmt"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/proto"
) )
// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required. // ListStates returns a list of all saved states
func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*proto.ListStatesResponse, error) {
mgr := statemanager.New(statemanager.GetDefaultStatePath())
stateNames, err := mgr.GetSavedStateNames()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get saved state names: %v", err)
}
states := make([]*proto.State, 0, len(stateNames))
for _, name := range stateNames {
states = append(states, &proto.State{
Name: name,
})
}
return &proto.ListStatesResponse{
States: states,
}, nil
}
// CleanState handles cleaning of states (performing cleanup operations)
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
if req.All {
// Reuse existing cleanup logic for all states
if err := restoreResidualState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clean all states: %v", err)
}
// Get count of cleaned states
mgr := statemanager.New(statemanager.GetDefaultStatePath())
stateNames, err := mgr.GetSavedStateNames()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get state count: %v", err)
}
return &proto.CleanStateResponse{
CleanedStates: int32(len(stateNames)),
}, nil
}
// Handle single state cleanup
mgr := statemanager.New(statemanager.GetDefaultStatePath())
registerStates(mgr)
if err := mgr.CleanupStateByName(req.StateName); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clean state %s: %v", req.StateName, err)
}
if err := mgr.PersistState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
}
return &proto.CleanStateResponse{
CleanedStates: 1,
}, nil
}
// DeleteState handles deletion of states without cleanup
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
mgr := statemanager.New(statemanager.GetDefaultStatePath())
var count int
var err error
if req.All {
count, err = mgr.DeleteAllStates()
} else {
err = mgr.DeleteStateByName(req.StateName)
if err == nil {
count = 1
}
}
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete state: %v", err)
}
// Persist the changes
if err := mgr.PersistState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
}
return &proto.DeleteStateResponse{
DeletedStates: int32(count),
}, nil
}
// restoreResidualState checks if the client was not shut down in a clean way and restores residual if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config. // Otherwise, we might not be able to connect to the management server to retrieve new config.
func restoreResidualState(ctx context.Context) error { func restoreResidualState(ctx context.Context) error {
path := statemanager.GetDefaultStatePath() path := statemanager.GetDefaultStatePath()
@@ -24,6 +124,7 @@ func restoreResidualState(ctx context.Context) error {
registerStates(mgr) registerStates(mgr)
var merr *multierror.Error var merr *multierror.Error
if err := mgr.PerformCleanup(); err != nil { if err := mgr.PerformCleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
} }

View File

@@ -61,6 +61,14 @@ type Info struct {
Files []File // for posture checks Files []File // for posture checks
} }
// StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string { func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx) md, hasMeta := metadata.FromOutgoingContext(ctx)

View File

@@ -10,13 +10,12 @@ import (
"os/exec" "os/exec"
"runtime" "runtime"
"strings" "strings"
"time"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -41,11 +40,10 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err) log.Warnf("failed to discover network addresses: %s", err)
} }
serialNum, prodName, manufacturer := sysInfo() start := time.Now()
si := updateStaticInfo()
env := Environment{ if time.Since(start) > 1*time.Second {
Cloud: detect_cloud.Detect(ctx), log.Warnf("updateStaticInfo took %s", time.Since(start))
Platform: detect_platform.Detect(ctx),
} }
gio := &Info{ gio := &Info{
@@ -57,10 +55,10 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
KernelVersion: release, KernelVersion: release,
NetworkAddresses: addrs, NetworkAddresses: addrs,
SystemSerialNumber: serialNum, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: prodName, SystemProductName: si.SystemProductName,
SystemManufacturer: manufacturer, SystemManufacturer: si.SystemManufacturer,
Environment: env, Environment: si.Environment,
} }
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()

View File

@@ -1,5 +1,4 @@
//go:build !android //go:build !android
// +build !android
package system package system
@@ -16,30 +15,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/zcalusic/sysinfo" "github.com/zcalusic/sysinfo"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
type SysInfoGetter interface { var (
GetSysInfo() SysInfo // it is override in tests
} getSystemInfo = defaultSysInfoImplementation
)
type SysInfoWrapper struct {
si sysinfo.SysInfo
}
func (s SysInfoWrapper) GetSysInfo() SysInfo {
s.si.GetSysInfo()
return SysInfo{
ChassisSerial: s.si.Chassis.Serial,
ProductSerial: s.si.Product.Serial,
BoardSerial: s.si.Board.Serial,
ProductName: s.si.Product.Name,
BoardName: s.si.Board.Name,
ProductVendor: s.si.Product.Vendor,
}
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
@@ -65,12 +47,10 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err) log.Warnf("failed to discover network addresses: %s", err)
} }
si := SysInfoWrapper{} start := time.Now()
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo()) si := updateStaticInfo()
if time.Since(start) > 1*time.Second {
env := Environment{ log.Warnf("updateStaticInfo took %s", time.Since(start))
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
} }
gio := &Info{ gio := &Info{
@@ -85,10 +65,10 @@ func GetInfo(ctx context.Context) *Info {
UIVersion: extractUserAgent(ctx), UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1], KernelVersion: osInfo[1],
NetworkAddresses: addrs, NetworkAddresses: addrs,
SystemSerialNumber: serialNum, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: prodName, SystemProductName: si.SystemProductName,
SystemManufacturer: manufacturer, SystemManufacturer: si.SystemManufacturer,
Environment: env, Environment: si.Environment,
} }
return gio return gio
@@ -108,9 +88,9 @@ func _getInfo() string {
return out.String() return out.String()
} }
func sysInfo(si SysInfo) (string, string, string) { func sysInfo() (string, string, string) {
isascii := regexp.MustCompile("^[[:ascii:]]+$") isascii := regexp.MustCompile("^[[:ascii:]]+$")
si := getSystemInfo()
serials := []string{si.ChassisSerial, si.ProductSerial} serials := []string{si.ChassisSerial, si.ProductSerial}
serial := "" serial := ""
@@ -141,3 +121,16 @@ func sysInfo(si SysInfo) (string, string, string) {
} }
return serial, name, manufacturer return serial, name, manufacturer
} }
func defaultSysInfoImplementation() SysInfo {
si := sysinfo.SysInfo{}
si.GetSysInfo()
return SysInfo{
ChassisSerial: si.Chassis.Serial,
ProductSerial: si.Product.Serial,
BoardSerial: si.Board.Serial,
ProductName: si.Product.Name,
BoardName: si.Board.Name,
ProductVendor: si.Product.Vendor,
}
}

View File

@@ -6,13 +6,12 @@ import (
"os" "os"
"runtime" "runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi" "github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -42,24 +41,10 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err) log.Warnf("failed to discover network addresses: %s", err)
} }
serialNum, err := sysNumber() start := time.Now()
if err != nil { si := updateStaticInfo()
log.Warnf("failed to get system serial number: %s", err) if time.Since(start) > 1*time.Second {
} log.Warnf("updateStaticInfo took %s", time.Since(start))
prodName, err := sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err := sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
} }
gio := &Info{ gio := &Info{
@@ -71,10 +56,10 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
KernelVersion: buildVersion, KernelVersion: buildVersion,
NetworkAddresses: addrs, NetworkAddresses: addrs,
SystemSerialNumber: serialNum, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: prodName, SystemProductName: si.SystemProductName,
SystemManufacturer: manufacturer, SystemManufacturer: si.SystemManufacturer,
Environment: env, Environment: si.Environment,
} }
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()
@@ -85,6 +70,26 @@ func GetInfo(ctx context.Context) *Info {
return gio return gio
} }
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func getOSNameAndVersion() (string, string) { func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "") query := wmi.CreateQuery(&dst, "")

View File

@@ -0,0 +1,46 @@
//go:build (linux && !android) || windows || (darwin && !ios)
package system
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
)
var (
staticInfo StaticInfo
once sync.Once
)
func init() {
go func() {
_ = updateStaticInfo()
}()
}
func updateStaticInfo() StaticInfo {
once.Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
wg.Done()
}()
go func() {
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
go func() {
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Wait()
})
return staticInfo
}

View File

@@ -183,7 +183,10 @@ func Test_sysInfo(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo) getSystemInfo = func() SysInfo {
return tt.sysInfo
}
gotSerialNum, gotProdName, gotManufacturer := sysInfo()
if gotSerialNum != tt.wantSerialNum { if gotSerialNum != tt.wantSerialNum {
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum) t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
} }

View File

@@ -572,6 +572,7 @@ func (s *serviceClient) onTrayReady() {
s.update.SetOnUpdateListener(s.onUpdateAvailable) s.update.SetOnUpdateListener(s.onUpdateAvailable)
go func() { go func() {
s.getSrvConfig() s.getSrvConfig()
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
for { for {
err := s.updateStatus() err := s.updateStatus()
if err != nil { if err != nil {

7
connprofile/iface.go Normal file
View File

@@ -0,0 +1,7 @@
package connprofile
import "github.com/netbirdio/netbird/client/iface/configurer"
type wgIface interface {
GetAllStat() (map[string]configurer.WGStats, error)
}

164
connprofile/profiler.go Normal file
View File

@@ -0,0 +1,164 @@
package connprofile
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
)
type Profile struct {
NetworkMapUpdate time.Time
OfferSent time.Time
OfferReceived time.Time
WireGuardConfigured time.Time
WireGuardConnected time.Time
}
type ConnProfiler struct {
profiles map[string]*Profile
profilesMu sync.Mutex
wgIface wgIface
wgMu sync.Mutex
}
func NewConnProfiler() *ConnProfiler {
return &ConnProfiler{
profiles: make(map[string]*Profile),
}
}
func (p *ConnProfiler) GetProfiles() map[string]Profile {
p.profilesMu.Lock()
defer p.profilesMu.Unlock()
copiedProfiles := make(map[string]Profile)
for key, profile := range p.profiles {
copiedProfiles[key] = Profile{
NetworkMapUpdate: profile.NetworkMapUpdate,
OfferSent: profile.OfferSent,
OfferReceived: profile.OfferReceived,
WireGuardConfigured: profile.WireGuardConfigured,
WireGuardConnected: profile.WireGuardConnected,
}
}
return copiedProfiles
}
func (p *ConnProfiler) WGInterfaceUP(wgInterface wgIface) {
p.wgMu.Lock()
defer p.wgMu.Unlock()
if p.wgIface != nil {
return
}
p.wgIface = wgInterface
go p.watchHandshakes()
}
func (p *ConnProfiler) NetworkMapUpdate(peerConfigs []*proto.RemotePeerConfig) {
p.profilesMu.Lock()
defer p.profilesMu.Unlock()
for _, peerConfig := range peerConfigs {
profile, ok := p.profiles[peerConfig.WgPubKey]
if ok {
continue
}
profile = &Profile{
NetworkMapUpdate: time.Now(),
}
p.profiles[peerConfig.WgPubKey] = profile
}
}
func (p *ConnProfiler) OfferSent(peerID string) {
p.profilesMu.Lock()
defer p.profilesMu.Unlock()
profile, ok := p.profiles[peerID]
if !ok {
log.Warnf("OfferSent: profile not found for peer %s", peerID)
return
}
if !profile.OfferSent.IsZero() {
return
}
profile.OfferSent = time.Now()
}
func (p *ConnProfiler) OfferAnswerReceived(peerID string) {
p.profilesMu.Lock()
defer p.profilesMu.Unlock()
profile, ok := p.profiles[peerID]
if !ok {
log.Warnf("OfferSent: profile not found for peer %s", peerID)
return
}
if !profile.OfferReceived.IsZero() {
return
}
profile.OfferReceived = time.Now()
}
func (p *ConnProfiler) WireGuardConfigured(peerID string) {
p.profilesMu.Lock()
defer p.profilesMu.Unlock()
profile, ok := p.profiles[peerID]
if !ok {
log.Warnf("OfferSent: profile not found for peer %s", peerID)
return
}
if !profile.WireGuardConfigured.IsZero() {
return
}
profile.WireGuardConfigured = time.Now()
}
func (p *ConnProfiler) watchHandshakes() {
ticker := time.NewTicker(300 * time.Millisecond)
for {
select {
case _ = <-ticker.C:
p.checkHandshakes()
}
}
}
func (p *ConnProfiler) checkHandshakes() {
stats, err := p.wgIface.GetAllStat()
if err != nil {
log.Errorf("watchHandshakes: %v", err)
return
}
p.profilesMu.Lock()
for peerID, profile := range p.profiles {
if !profile.WireGuardConnected.IsZero() {
continue
}
stat, ok := stats[peerID]
if !ok {
continue
}
if stat.LastHandshake.IsZero() {
continue
}
if stat.LastHandshake.Before(time.Now().Add(-100 * time.Hour)) {
continue
}
profile.WireGuardConnected = stat.LastHandshake
}
p.profilesMu.Unlock()
}

46
connprofile/report.go Normal file
View File

@@ -0,0 +1,46 @@
package connprofile
import (
"encoding/json"
"time"
log "github.com/sirupsen/logrus"
)
type Report struct {
NetworkMapUpdate time.Time
OfferSent float64
OfferReceived float64
WireGuardConfigured float64
WireGuardConnected float64
}
func report() {
ticker := time.NewTicker(5 * time.Second)
for {
select {
case _ = <-ticker.C:
printJson()
}
}
}
func printJson() {
profiles := Profiler.GetProfiles()
reports := make(map[string]Report)
for key, profile := range profiles {
reports[key] = Report{
NetworkMapUpdate: profile.NetworkMapUpdate,
OfferSent: profile.OfferSent.Sub(profile.NetworkMapUpdate).Seconds(),
OfferReceived: profile.OfferReceived.Sub(profile.OfferSent).Seconds(),
WireGuardConfigured: profile.WireGuardConfigured.Sub(profile.OfferReceived).Seconds(),
WireGuardConnected: profile.WireGuardConnected.Sub(profile.WireGuardConfigured).Seconds(),
}
}
jsonData, err := json.MarshalIndent(reports, "", " ")
if err != nil {
log.Errorf("failed to marshal profiles: %v", err)
}
log.Infof("profiles: %s", jsonData)
}

10
connprofile/static.go Normal file
View File

@@ -0,0 +1,10 @@
package connprofile
var (
Profiler *ConnProfiler
)
func init() {
Profiler = NewConnProfiler()
go report()
}

18
go.mod
View File

@@ -19,13 +19,13 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2 github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.28.0 golang.org/x/crypto v0.31.0
golang.org/x/sys v0.26.0 golang.org/x/sys v0.28.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.64.1 google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.2
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
@@ -80,7 +80,7 @@ require (
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
github.com/things-go/go-socks5 v0.0.4 github.com/things-go/go-socks5 v0.0.4
github.com/yusufpapurcu/wmi v1.2.4 github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.0.2 github.com/zcalusic/sysinfo v1.1.3
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel v1.26.0
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0
@@ -92,8 +92,8 @@ require (
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.30.0 golang.org/x/net v0.30.0
golang.org/x/oauth2 v0.19.0 golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.8.0 golang.org/x/sync v0.10.0
golang.org/x/term v0.25.0 golang.org/x/term v0.27.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
@@ -219,12 +219,12 @@ require (
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.19.0 // indirect golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

36
go.sum
View File

@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g= github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -708,8 +708,8 @@ github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc= github.com/zcalusic/sysinfo v1.1.3 h1:u/AVENkuoikKuIZ4sUEJ6iibpmQP6YpGD8SSMCrqAF0=
github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30= github.com/zcalusic/sysinfo v1.1.3/go.mod h1:NX+qYnWGtJVPV0yWldff9uppNKU4h40hJIRPf/pGLv4=
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
@@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1151,8 +1151,8 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U=
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 h1:AgADTJarZTBqgjiUzRgfaBchgYB3/WFTC80GPwsMcRI= google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
@@ -1189,8 +1189,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -530,7 +530,7 @@ renderCaddyfile() {
{ {
debug debug
servers :80,:443 { servers :80,:443 {
protocols h1 h2c protocols h1 h2c h2 h3
} }
} }
@@ -788,6 +788,7 @@ services:
networks: [ netbird ] networks: [ netbird ]
ports: ports:
- '443:443' - '443:443'
- '443:443/udp'
- '80:80' - '80:80'
- '8080:8080' - '8080:8080'
volumes: volumes:

View File

@@ -42,6 +42,7 @@ import (
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
httpapi "github.com/netbirdio/netbird/management/server/http" httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/metrics"
@@ -257,7 +258,7 @@ var (
return fmt.Errorf("failed creating JWT validator: %v", err) return fmt.Errorf("failed creating JWT validator: %v", err)
} }
httpAPIAuthCfg := httpapi.AuthCfg{ httpAPIAuthCfg := configs.AuthCfg{
Issuer: config.HttpConfig.AuthIssuer, Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience, Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim, UserIDClaim: config.HttpConfig.AuthUserIDClaim,

View File

@@ -110,11 +110,10 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@@ -140,7 +139,7 @@ type AccountManager interface {
HasConnectedChannel(peerID string) bool HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager GetIdpManager() idp.Manager
@@ -966,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
} }
// UserGroupsAddToPeers adds groups to all peers of user // UserGroupsAddToPeers adds groups to all peers of user
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
userPeers := make(map[string]struct{}) userPeers := make(map[string]struct{})
for pid, peer := range a.Peers { for pid, peer := range a.Peers {
if peer.UserID == userID { if peer.UserID == userID {
@@ -980,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
continue continue
} }
oldPeers := group.Peers
groupPeers := make(map[string]struct{}) groupPeers := make(map[string]struct{})
for _, pid := range group.Peers { for _, pid := range group.Peers {
groupPeers[pid] = struct{}{} groupPeers[pid] = struct{}{}
@@ -993,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
for pid := range groupPeers { for pid := range groupPeers {
group.Peers = append(group.Peers, pid) group.Peers = append(group.Peers, pid)
} }
groupUpdates[gid] = difference(group.Peers, oldPeers)
} }
return groupUpdates
} }
// UserGroupsRemoveFromPeers removes groups from all peers of user // UserGroupsRemoveFromPeers removes groups from all peers of user
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
for _, gid := range groups { for _, gid := range groups {
group, ok := a.Groups[gid] group, ok := a.Groups[gid]
if !ok || group.Name == "All" { if !ok || group.Name == "All" {
continue continue
} }
oldPeers := group.Peers
update := make([]string, 0, len(group.Peers)) update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers { for _, pid := range group.Peers {
peer, ok := a.Peers[pid] peer, ok := a.Peers[pid]
@@ -1014,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
} }
} }
group.Peers = update group.Peers = update
groupUpdates[gid] = difference(oldPeers, group.Peers)
} }
return groupUpdates
} }
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
@@ -1176,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err return nil, err
} }
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, fmt.Errorf("groups propagation failed: %w", err)
}
updatedAccount := account.UpdateSettings(newSettings) updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
@@ -1186,21 +1206,39 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil return updatedAccount, nil
} }
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled if newSettings.GroupsPropagationEnabled {
if !newSettings.PeerInactivityExpirationEnabled { am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
event = activity.AccountPeerInactivityExpirationDisabled // Todo: retroactively add user groups to all peers
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { return nil
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) }
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
} else {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
} }
return nil return nil
@@ -1435,7 +1473,7 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager // addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -2029,7 +2067,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := transaction.GetAccountGroups(ctx, accountID) groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2059,7 +2097,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, accountID) groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2083,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err) return fmt.Errorf("error incrementing network serial: %w", err)
} }
} }
@@ -2101,7 +2139,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else { } else {
@@ -2114,7 +2152,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else { } else {
@@ -2127,14 +2165,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
if err != nil { if err != nil {
return fmt.Errorf("error getting account: %w", err) return err
} }
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
} }
@@ -2290,12 +2333,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, status.NewGetAccountError(err)
} }
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
} }
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
@@ -2314,12 +2357,12 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return status.NewGetAccountError(err)
} }
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
} }
return nil return nil
@@ -2335,6 +2378,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlock := am.Store.AcquireReadLockByUID(ctx, accountID) unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock() defer unlock()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
@@ -2398,12 +2444,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(ctx, accountID) am.updateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
return
}
am.updateAccountPeers(ctx, updatedAccount)
} }
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {

View File

@@ -6,13 +6,17 @@ import (
b64 "encoding/base64" b64 "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net" "net"
"os"
"reflect" "reflect"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -1038,7 +1042,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
} }
b.Run("public without account ID", func(b *testing.B) { b.Run("public without account ID", func(b *testing.B) {
//b.ResetTimer() // b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims) _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil { if err != nil {
@@ -1048,7 +1052,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
}) })
b.Run("private without account ID", func(b *testing.B) { b.Run("private without account ID", func(b *testing.B) {
//b.ResetTimer() // b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil { if err != nil {
@@ -1059,7 +1063,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
b.Run("private with account ID", func(b *testing.B) { b.Run("private with account ID", func(b *testing.B) {
claims.AccountId = id claims.AccountId = id
//b.ResetTimer() // b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil { if err != nil {
@@ -1238,8 +1242,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
return return
} }
policy := Policy{ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1250,8 +1253,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1320,19 +1322,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
go func() { go func() {
@@ -1345,7 +1334,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
} }
}() }()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
})
if err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
} }
@@ -1366,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return return
} }
policy := Policy{ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1377,9 +1378,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
if err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("save policy: %v", err) t.Errorf("save policy: %v", err)
return return
} }
@@ -1413,13 +1413,20 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{ err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
require.NoError(t, err, "failed to save group")
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
} }
policy := Policy{ policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1430,14 +1437,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
if err != nil {
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("save policy: %v", err) t.Errorf("save policy: %v", err)
return return
} }
@@ -1460,7 +1461,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return return
} }
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
t.Errorf("delete group: %v", err) t.Errorf("delete group: %v", err)
return return
} }
@@ -2714,7 +2715,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2734,7 +2735,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2773,7 +2774,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims) err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")
@@ -2985,3 +2986,218 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage)
t.Error("Timed out waiting for update message") t.Error("Timed out waiting for update message")
} }
} }
func BenchmarkSyncAndMarkPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 1, 3, 3, 10},
{"Medium", 500, 100, 7, 13, 10, 60},
{"Large", 5000, 200, 65, 80, 60, 170},
{"Small single", 50, 10, 1, 3, 3, 60},
{"Medium single", 500, 10, 7, 13, 10, 26},
{"Large 5", 5000, 15, 65, 80, 60, 170},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 102, 110, 102, 120},
{"Medium", 500, 100, 105, 140, 105, 170},
{"Large", 5000, 200, 160, 200, 160, 270},
{"Small single", 50, 10, 102, 110, 102, 120},
{"Medium single", 500, 10, 105, 140, 105, 170},
{"Large 5", 5000, 15, 160, 200, 160, 270},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
WireGuardPubKey: account.Peers["peer-1"].Key,
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
UserID: "regular_user",
SetupKey: "",
ConnectionIP: net.IP{1, 1, 1, 1},
})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func BenchmarkLoginPeer_NewPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 107, 120, 107, 140},
{"Medium", 500, 100, 105, 140, 105, 170},
{"Large", 5000, 200, 180, 220, 180, 340},
{"Small single", 50, 10, 107, 120, 105, 140},
{"Medium single", 500, 10, 105, 140, 105, 170},
{"Large 5", 5000, 15, 180, 220, 180, 340},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
UserID: "regular_user",
SetupKey: "",
ConnectionIP: net.IP{1, 1, 1, 1},
})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}

View File

@@ -148,6 +148,9 @@ const (
AccountPeerInactivityExpirationDurationUpdated Activity = 67 AccountPeerInactivityExpirationDurationUpdated Activity = 67
SetupKeyDeleted Activity = 68 SetupKeyDeleted Activity = 68
UserGroupPropagationEnabled Activity = 69
UserGroupPropagationDisabled Activity = 70
) )
var activityMap = map[Activity]Code{ var activityMap = map[Activity]Code{
@@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"strconv" "strconv"
"sync" "sync"
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
@@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil { if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
} }
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) if err != nil {
if err != nil {
return err
}
}
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
for _, id := range addedGroups { if user.AccountID != accountID {
group := account.GetGroup(id) return status.NewUserNotPartOfAccountError()
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
} }
for _, id := range removedGroups { if !user.HasAdminPower() {
group := account.GetGroup(id) return status.NewAdminPermissionError()
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
} }
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { var updateAccountPeers bool
am.updateAccountPeers(ctx, account) var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
}
for _, groupID := range addedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
})
}
for _, groupID := range removedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
})
}
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {
return nil
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}
return validateGroups(settings.DisabledManagementGroups, groups)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ protoUpdate := &proto.DNSConfig{

View File

@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux // It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error { func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now() start := time.Now()
err := util.WriteJson(file, s) err := util.WriteJson(context.Background(), file, s)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -6,11 +6,12 @@ import (
"fmt" "fmt"
"slices" "slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@@ -27,18 +28,17 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups // CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users") return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
} }
return nil return nil
@@ -49,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -58,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountGroups(ctx, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@@ -77,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account. // SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock. // Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var eventsToStore []func() var eventsToStore []func()
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
for _, newGroup := range newGroups { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { groupIDs := make([]string, 0, len(groups))
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) for _, newGroup := range groups {
} if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
} }
// Avoid duplicate groups only for the API issued groups. newGroup.AccountID = accountID
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. groupsToSave = append(groupsToSave, newGroup)
if existingGroup != nil { groupIDs = append(groupIDs, newGroup.ID)
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String() events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
} }
for _, peerID := range newGroup.Peers { updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if account.Peers[peerID] == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return err
}
} }
oldGroup := account.Groups[newGroup.ID] if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
account.Groups[newGroup.ID] = newGroup return err
}
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
eventsToStore = append(eventsToStore, events...) })
} if err != nil {
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
}
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
storeEvent() storeEvent()
} }
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func() var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if oldGroup != nil { oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
@@ -159,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range addedPeers { modifiedPeers := slices.Concat(addedPeers, removedPeers)
peer := account.Peers[p] peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, meta := map[string]any{
map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID,
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), }
}) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
}) })
} }
for _, p := range removedPeers { for _, peerID := range removedPeers {
peer := account.Peers[p] peer, ok := peers[peerID]
if peer == nil { if !ok {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, meta := map[string]any{
map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID,
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), }
}) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
}) })
} }
@@ -210,42 +210,10 @@ func difference(a, b []string) []string {
} }
// DeleteGroup object of the peers. // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return nil
}
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
return nil
} }
// DeleteGroups deletes groups from an account. // DeleteGroups deletes groups from an account.
@@ -254,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
// //
// If an error occurs while deleting a group, the function skips it and continues deleting other groups. // If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end. // Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var allErrors error var allErrors error
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, ok := account.Groups[groupID] group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
if !ok { if err != nil {
continue allErrors = errors.Join(allErrors, err)
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
} }
if err := validateDeleteGroup(account, group, userId); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) return err
continue
} }
delete(account.Groups, groupID) return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
deletedGroups = append(deletedGroups, group) })
} if err != nil {
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
for _, g := range deletedGroups { for _, group := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
} }
return allErrors return allErrors
} }
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if updateAccountPeers {
if !ok { am.updateAccountPeers(ctx, accountID)
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
} }
return nil return nil
@@ -351,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if updateAccountPeers {
if !ok { am.updateAccountPeers(ctx, accountID)
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
}
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
} }
return nil return nil
} }
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { // validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
}
}
// Prevent duplicate groups for API-issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration { if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID] executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if executingUser == nil { if err != nil {
return status.Errorf(status.NotFound, "user not found") return err
} }
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
} }
} }
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name} return &GroupLinkError{"name server groups", linkedDns.Name}
} }
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name} return &GroupLinkError{"policy", linkedPolicy.Name}
} }
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name} return &GroupLinkError{"setup key", linkedSetupKey.Name}
} }
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id} return &GroupLinkError{"user", linkedUser.Id}
} }
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
if account.Settings.Extra != nil { settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { if err != nil {
return &GroupLinkError{"integrated validator", group.Name} return err
} }
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
} }
return nil return nil
} }
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes { for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r return true, r
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies { for _, policy := range policies {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -446,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
} }
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups { for _, dns := range nameServerGroups {
for _, g := range dns.Groups { for _, g := range dns.Groups {
if g == groupID { if g == groupID {
@@ -454,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
} }
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys { for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) { if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey return true, setupKey
@@ -468,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
} }
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users { for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) { if slices.Contains(user.AutoGroups, groupID) {
return true, user return true, user
@@ -477,8 +537,36 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
return false, nil return false, nil
} }
// anyGroupHasPeers checks if any of the given groups in the account have peers. // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool { func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() { if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true return true
@@ -487,21 +575,18 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
return false return false
} }
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { // anyGroupHasPeers checks if any of the given groups in the account have peers.
for _, groupID := range groupIDs { func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
return true if err != nil {
} return false, err
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { }
return true
} for _, group := range groups {
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { if group.HasPeers() {
return true return true, nil
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
} }
} }
return false return false, nil
} }

View File

@@ -49,3 +49,35 @@ func (g *Group) Copy() *Group {
func (g *Group) HasPeers() bool { func (g *Group) HasPeers() bool {
return len(g.Peers) > 0 return len(g.Peers) > 0
} }
// IsGroupAll checks if the group is a default "All" group.
func (g *Group) IsGroupAll() bool {
return g.Name == "All"
}
// AddPeer adds peerID to Peers if not present, returning true if added.
func (g *Group) AddPeer(peerID string) bool {
if peerID == "" {
return false
}
for _, itemID := range g.Peers {
if itemID == peerID {
return false
}
}
g.Peers = append(g.Peers, peerID)
return true
}
// RemovePeer removes peerID from Peers if present, returning true if removed.
func (g *Group) RemovePeer(peerID string) bool {
for i, itemID := range g.Peers {
if itemID == peerID {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
return true
}
}
return false
}

View File

@@ -0,0 +1,90 @@
package group
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddPeer(t *testing.T) {
t.Run("add new peer to empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to non-empty slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add duplicate peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer1"
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("add empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}
func TestRemovePeer(t *testing.T) {
t.Run("remove existing peer from slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
peerID := "peer2"
assert.True(t, group.RemovePeer(peerID))
assert.NotContains(t, group.Peers, peerID)
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
})
t.Run("remove peer from nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Nil(t, group.Peers)
})
t.Run("remove non-existent peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from single-item slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1"}}
peerID := "peer1"
assert.True(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
assert.NotContains(t, group.Peers, peerID)
})
t.Run("remove empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}

View File

@@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
{ {
name: "delete non-existent group", name: "delete non-existent group",
groupIDs: []string{"non-existent-group"}, groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"}, expectedReasons: []string{"group: non-existent-group not found"},
}, },
{ {
name: "delete multiple groups with mixed results", name: "delete multiple groups with mixed results",
@@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}) })
// adding a group to policy // adding a group to policy
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
}, false) })
assert.NoError(t, err) assert.NoError(t, err)
// Saving a group linked to policy should update account peers and send peer update // Saving a group linked to policy should update account peers and send peer update

View File

@@ -180,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
return mapError(ctx, err) return mapError(ctx, err)
} }
@@ -207,6 +208,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// handleUpdates sends updates to the connected peer until the updates channel is closed. // handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for { for {
select { select {
// condition when there are some updates // condition when there are some updates
@@ -260,10 +262,15 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
unlock := s.acquirePeerLockByUID(ctx, peer.Key) unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock() defer unlock()
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.secretsManager.CancelRefresh(peer.ID) s.secretsManager.CancelRefresh(peer.ID)
s.ephemeralManager.OnPeerDisconnected(ctx, peer) s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
} }
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {

View File

@@ -439,17 +439,13 @@ components:
example: 5 example: 5
required: required:
- accessible_peers_count - accessible_peers_count
SetupKey: SetupKeyBase:
type: object type: object
properties: properties:
id: id:
description: Setup Key ID description: Setup Key ID
type: string type: string
example: 2531583362 example: 2531583362
key:
description: Setup Key value
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
name: name:
description: Setup key name identifier description: Setup key name identifier
type: string type: string
@@ -518,22 +514,31 @@ components:
- updated_at - updated_at
- usage_limit - usage_limit
- ephemeral - ephemeral
SetupKeyClear:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as plain text
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
required:
- key
SetupKey:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as secret
type: string
example: A6160****
required:
- key
SetupKeyRequest: SetupKeyRequest:
type: object type: object
properties: properties:
name:
description: Setup Key name
type: string
example: Default key
type:
description: Setup key type, one-off for single time usage and reusable
type: string
example: reusable
expires_in:
description: Expiration time in seconds, 0 will mean the key never expires
type: integer
minimum: 0
example: 86400
revoked: revoked:
description: Setup key revocation status description: Setup key revocation status
type: boolean type: boolean
@@ -544,21 +549,9 @@ components:
items: items:
type: string type: string
example: "ch8i4ug6lnn4g9hqv7m0" example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
example: 0
ephemeral:
description: Indicate that the peer will be ephemeral or not
type: boolean
example: true
required: required:
- name
- type
- expires_in
- revoked - revoked
- auto_groups - auto_groups
- usage_limit
CreateSetupKeyRequest: CreateSetupKeyRequest:
type: object type: object
properties: properties:
@@ -1943,7 +1936,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/SetupKey' $ref: '#/components/schemas/SetupKeyClear'
'400': '400':
"$ref": "#/components/responses/bad_request" "$ref": "#/components/responses/bad_request"
'401': '401':

View File

@@ -1062,7 +1062,94 @@ type SetupKey struct {
// Id Setup Key ID // Id Setup Key ID
Id string `json:"id"` Id string `json:"id"`
// Key Setup Key value // Key Setup Key as secret
Key string `json:"key"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyBase defines model for SetupKeyBase.
type SetupKeyBase struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyClear defines model for SetupKeyClear.
type SetupKeyClear struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// Key Setup Key as plain text
Key string `json:"key"` Key string `json:"key"`
// LastUsed Setup key last usage date // LastUsed Setup key last usage date
@@ -1098,23 +1185,8 @@ type SetupKeyRequest struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key // AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral *bool `json:"ephemeral,omitempty"`
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
ExpiresIn int `json:"expires_in"`
// Name Setup Key name
Name string `json:"name"`
// Revoked Setup key revocation status // Revoked Setup key revocation status
Revoked bool `json:"revoked"` Revoked bool `json:"revoked"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
} }
// User defines model for User. // User defines model for User.

View File

@@ -0,0 +1,9 @@
package configs
// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Audience string
UserIDClaim string
KeysLocation string
}

View File

@@ -12,6 +12,16 @@ import (
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
"github.com/netbirdio/netbird/management/server/http/handlers/events"
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -20,27 +30,15 @@ import (
const apiPrefix = "/api" const apiPrefix = "/api"
// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Audience string
UserIDClaim string
KeysLocation string
}
type apiHandler struct { type apiHandler struct {
Router *mux.Router Router *mux.Router
AccountManager s.AccountManager AccountManager s.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager *geolocation.Geolocation
AuthCfg AuthCfg AuthCfg configs.AuthCfg
}
// EmptyObject is an empty struct used to return empty JSON object
type emptyObject struct {
} }
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor( claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@@ -86,122 +84,15 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
return nil, fmt.Errorf("register integrations endpoints: %w", err) return nil, fmt.Errorf("register integrations endpoints: %w", err)
} }
api.addAccountsEndpoint() accounts.AddEndpoints(api.AccountManager, authCfg, router)
api.addPeersEndpoint() peers.AddEndpoints(api.AccountManager, authCfg, router)
api.addUsersEndpoint() users.AddEndpoints(api.AccountManager, authCfg, router)
api.addUsersTokensEndpoint() setup_keys.AddEndpoints(api.AccountManager, authCfg, router)
api.addSetupKeysEndpoint() policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router)
api.addPoliciesEndpoint() groups.AddEndpoints(api.AccountManager, authCfg, router)
api.addGroupsEndpoint() routes.AddEndpoints(api.AccountManager, authCfg, router)
api.addRoutesEndpoint() dns.AddEndpoints(api.AccountManager, authCfg, router)
api.addDNSNameserversEndpoint() events.AddEndpoints(api.AccountManager, authCfg, router)
api.addDNSSettingEndpoint()
api.addEventsEndpoint()
api.addPostureCheckEndpoint()
api.addLocationsEndpoint()
return rootRouter, nil return rootRouter, nil
} }
func (apiHandler *apiHandler) addAccountsEndpoint() {
accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.DeleteAccount).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addPeersEndpoint() {
peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersEndpoint() {
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addSetupKeysEndpoint() {
keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addPoliciesEndpoint() {
policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addGroupsEndpoint() {
groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addRoutesEndpoint() {
routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addDNSNameserversEndpoint() {
nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addDNSSettingEndpoint() {
dnsSettingsHandler := NewDNSSettingsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS")
}
func (apiHandler *apiHandler) addEventsEndpoint() {
eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addPostureCheckEndpoint() {
postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addLocationsEndpoint() {
locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS")
}

View File

@@ -1,4 +1,4 @@
package http package accounts
import ( import (
"encoding/json" "encoding/json"
@@ -10,20 +10,28 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// AccountsHandler is a handler that handles the server.Account HTTP endpoints // handler is a handler that handles the server.Account HTTP endpoints
type AccountsHandler struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewAccountsHandler creates a new AccountsHandler HTTP handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *AccountsHandler { accountsHandler := newHandler(accountManager, authCfg)
return &AccountsHandler{ router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@@ -32,8 +40,8 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
} }
} }
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. // getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@@ -51,8 +59,8 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
} }
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@@ -111,8 +119,8 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// DeleteAccount is a HTTP DELETE handler to delete an account // deleteAccount is a HTTP DELETE handler to delete an account
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
vars := mux.Vars(r) vars := mux.Vars(r)
targetAccountID := vars["accountId"] targetAccountID := vars["accountId"]
@@ -127,7 +135,7 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func toAccountResponse(accountID string, settings *server.Settings) *api.Account { func toAccountResponse(accountID string, settings *server.Settings) *api.Account {

View File

@@ -1,4 +1,4 @@
package http package accounts
import ( import (
"bytes" "bytes"
@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { func initAccountsTestData(account *server.Account, admin *server.User) *handler {
return &AccountsHandler{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return account.Id, admin.Id, nil return account.Id, admin.Id, nil
@@ -89,7 +89,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
requestBody io.Reader requestBody io.Reader
}{ }{
{ {
name: "GetAllAccounts OK", name: "getAllAccounts OK",
expectedBody: true, expectedBody: true,
requestType: http.MethodGet, requestType: http.MethodGet,
requestPath: "/api/accounts", requestPath: "/api/accounts",
@@ -189,8 +189,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET") router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT") router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -1,26 +1,39 @@
package http package dns
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
// DNSSettingsHandler is a handler that returns the DNS settings of the account // dnsSettingsHandler is a handler that returns the DNS settings of the account
type DNSSettingsHandler struct { type dnsSettingsHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewDNSSettingsHandler returns a new instance of DNSSettingsHandler handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettingsHandler { addDNSSettingEndpoint(accountManager, authCfg, router)
return &DNSSettingsHandler{ addDNSNameserversEndpoint(accountManager, authCfg, router)
}
func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg)
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
}
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler {
return &dnsSettingsHandler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@@ -29,8 +42,8 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
} }
} }
// GetDNSSettings returns the DNS settings for the account // getDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@@ -52,8 +65,8 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
util.WriteJSONObject(r.Context(), w, apiDNSSettings) util.WriteJSONObject(r.Context(), w, apiDNSSettings)
} }
// UpdateDNSSettings handles update to DNS settings of an account // updateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@@ -1,4 +1,4 @@
package http package dns
import ( import (
"bytes" "bytes"
@@ -40,8 +40,8 @@ var testingDNSSettingsAccount = &server.Account{
DNSSettings: baseExistingDNSSettings, DNSSettings: baseExistingDNSSettings,
} }
func initDNSSettingsTestData() *DNSSettingsHandler { func initDNSSettingsTestData() *dnsSettingsHandler {
return &DNSSettingsHandler{ return &dnsSettingsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil return &testingDNSSettingsAccount.DNSSettings, nil
@@ -120,8 +120,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET") router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT") router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -1,4 +1,4 @@
package http package dns
import ( import (
"encoding/json" "encoding/json"
@@ -11,20 +11,30 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// NameserversHandler is the nameserver group handler of the account // nameserversHandler is the nameserver group handler of the account
type NameserversHandler struct { type nameserversHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewNameserversHandler returns a new instance of NameserversHandler handler func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg) *NameserversHandler { nameserversHandler := newNameserversHandler(accountManager, authCfg)
return &NameserversHandler{ router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
}
// newNameserversHandler returns a new instance of nameserversHandler handler
func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler {
return &nameserversHandler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@@ -33,8 +43,8 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
} }
} }
// GetAllNameservers returns the list of nameserver groups for the account // getAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@@ -57,8 +67,8 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
util.WriteJSONObject(r.Context(), w, apiNameservers) util.WriteJSONObject(r.Context(), w, apiNameservers)
} }
// CreateNameserverGroup handles nameserver group creation request // createNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@@ -90,8 +100,8 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID // updateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@@ -141,8 +151,8 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// DeleteNameserverGroup handles nameserver group deletion request // deleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@@ -162,11 +172,11 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
// GetNameserverGroup handles a nameserver group Get request identified by ID // getNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@@ -1,4 +1,4 @@
package http package dns
import ( import (
"bytes" "bytes"
@@ -50,8 +50,8 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
Enabled: true, Enabled: true,
} }
func initNameserversTestData() *NameserversHandler { func initNameserversTestData() *nameserversHandler {
return &NameserversHandler{ return &nameserversHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if nsGroupID == existingNSGroupID { if nsGroupID == existingNSGroupID {
@@ -206,10 +206,10 @@ func TestNameserversHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST") router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@@ -1,28 +1,35 @@
package http package events
import ( import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
// EventsHandler HTTP handler // handler HTTP handler
type EventsHandler struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewEventsHandler creates a new EventsHandler HTTP handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *EventsHandler { eventsHandler := newHandler(accountManager, authCfg)
return &EventsHandler{ router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
}
// newHandler creates a new events handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@@ -31,8 +38,8 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
} }
} }
// GetAllEvents list of the given account // getAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@@ -60,7 +67,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, events) util.WriteJSONObject(r.Context(), w, events)
} }
func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
// build email, name maps based on users // build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
if err != nil { if err != nil {

View File

@@ -1,4 +1,4 @@
package http package events
import ( import (
"context" "context"
@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { func initEventsTestData(account string, events ...*activity.Event) *handler {
return &EventsHandler{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
if accountID == account { if accountID == account {
@@ -183,7 +183,7 @@ func TestEvents_GetEvents(t *testing.T) {
requestBody io.Reader requestBody io.Reader
}{ }{
{ {
name: "GetAllEvents OK", name: "getAllEvents OK",
expectedBody: true, expectedBody: true,
requestType: http.MethodGet, requestType: http.MethodGet,
requestPath: "/api/events/", requestPath: "/api/events/",
@@ -201,7 +201,7 @@ func TestEvents_GetEvents(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/events/", handler.GetAllEvents).Methods("GET") router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

Some files were not shown because too many files have changed in this diff Show More