Compare commits

...

45 Commits

Author SHA1 Message Date
Pascal Fischer
933cf1c84d extract getting engine 2024-10-01 15:44:32 +02:00
Pascal Fischer
5eb936b49e add engine stop to run in foreground mode 2024-10-01 15:39:38 +02:00
Simen
24c0aaa745 Install sh alpine fixes (#2678)
* Made changes to the peer install script that makes it work on alpine linux without changes

* fix small oversight with doas fix

* use try catch approach when curling binaries
2024-10-01 13:32:58 +02:00
pascal-fischer
16179db599 [management] Propagate metrics (#2667) 2024-09-30 22:18:10 +02:00
Maycon Santos
e27f85b317 Update docker creds (#2677) 2024-09-30 20:07:21 +02:00
Gianluca Boiano
2fd60b2cb4 Specify goreleaser version and update to 2 (#2673) 2024-09-30 16:43:34 +02:00
Zoltan Papp
3dca6099d4 Fix ebpf close function (#2672) 2024-09-30 10:34:57 +02:00
pascal-fischer
cfbcf507fb propagate meter (#2668) 2024-09-29 20:23:34 +02:00
pascal-fischer
52ae693c9e [signal] add context to signal-dispatcher (#2662) 2024-09-29 00:22:47 +02:00
adasauce
58ff7ab797 [management] improve zitadel idp error response detail by decoding errors (#2634)
* [management] improve zitadel idp error response detail by decoding errors

* [management] extend readZitadelError to be used for requestJWTToken

more generically parse the error returned by zitadel.

* fix lint

---------

Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-27 22:21:34 +03:00
Bethuel Mmbaga
acb73bd64a [management] Remove redundant get account calls in GetAccountFromToken (#2615)
* refactor access control middleware and user access by JWT groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor jwt groups extractor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor handlers to get account when necessary

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountWithAuthorizationClaims

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* revert handles change

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove GetUserByID from account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountWithAuthorizationClaims to return account id

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor handlers to use GetAccountIDFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove locks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add GetGroupByName from store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add GetGroupByID from store and refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor retrieval of policy and posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor user permissions and retrieves PAT

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor route, setupkey, nameserver and dns to get record(s) from store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix add missing policy source posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add store lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add get account

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-27 17:10:50 +03:00
Zoltan Papp
4ebf6e1c4c [client] Close the remote conn in proxy (#2626)
Port the conn close call to eBPF proxy
2024-09-25 18:50:10 +02:00
pascal-fischer
1e4a0f77e2 Add get DB method to store (#2650) 2024-09-25 18:22:27 +02:00
Viktor Liu
b51d75204b [client] Anonymize relay address in status peers view (#2640) 2024-09-24 20:58:18 +02:00
Viktor Liu
e7d52c8c95 [client] Fix error count formatting (#2641) 2024-09-24 20:57:56 +02:00
Viktor Liu
ab82302c95 [client] Remove usage of custom dialer for localhost (#2639)
* Downgrade error log level for network monitor warnings

* Do not use custom dialer for localhost
2024-09-24 12:29:15 +02:00
pascal-fischer
d47be154ea [misc] Fix ip range posture check example (#2628) 2024-09-23 10:02:03 +02:00
Bethuel Mmbaga
35c892aea3 [management] Restrict accessible peers to user-owned peers for non-admins (#2618)
* Restrict accessible peers to user-owned peers for non-admin users

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add service user test

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* reuse account from token

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* return error when peer not found

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-20 12:36:58 +03:00
Zoltan Papp
fc4b37f7bc Exit from processConnResults after all tries (#2621)
* Exit from processConnResults after all tries

If all server is unavailable then the server picker never return
because we never close the result channel.
Count the number of the results and exit when we reached the
expected size
2024-09-19 13:49:28 +02:00
Zoltan Papp
6f0fd1d1b3 - Increase queue size and drop the overflowed messages (#2617)
- Explicit close the net.Conn in user space wgProxy when close the wgProxy
- Add extra logs
2024-09-19 13:49:09 +02:00
Zoltan Papp
28cbb4b70f [client] Cancel the context of wg watcher when the go routine exit (#2612) 2024-09-17 12:10:17 +02:00
Zoltan Papp
1104c9c048 [client] Fix race condition while read/write conn status in peer conn (#2607) 2024-09-17 11:15:14 +02:00
Maycon Santos
5bc601111d [relay] Add health check attempt threshold (#2609)
* Add health check attempt threshold for receiver

* Add health check attempt threshold for sender
2024-09-17 10:04:17 +02:00
Zoltan Papp
b74951f29e [client] Enforce permissions on Win (#2568)
Enforce folder permission on Windows, giving only administrators and system access to the NetBird folder.
2024-09-16 22:42:37 +02:00
Zoltan Papp
97e10e440c Fix leaked server connections (#2596)
Fix leaked server connections

close unused connections in the client lib
close deprecated connection in the server lib
The Server Picker is reusable in the guard if we want in the future. So we can support the server address changes.

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Add logging

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-09-16 16:11:10 +02:00
pascal-fischer
6c50b0c84b [management] Add transaction to addPeer (#2469)
This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
2024-09-16 15:47:03 +02:00
pascal-fischer
730dd1733e [signal] Fix signal active peers metrics (#2591) 2024-09-15 16:46:55 +02:00
Bethuel Mmbaga
82739e2832 [management] fix legacy decrypting of empty values (#2595)
* allow legacy decrypting on empty values

* validate source size and padding limits

* added tests

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-09-15 16:22:46 +02:00
Maycon Santos
fa7767e612 Fix get management and signal state race condition (#2570)
* Fix get management and signal state race condition

* fix get full status lock
2024-09-15 16:07:26 +02:00
benniekiss
f1171198de [management] Add command flag to set metrics port for signal and relay service, and update management port (#2599)
* add flags to customize metrics port for relay and signal

* change management default metrics port to match other services
2024-09-14 10:34:32 +02:00
Zoltan Papp
9e041b7f82 Fix blocked net.Conn Close call (#2600) 2024-09-14 10:27:37 +02:00
Zoltan Papp
b4c8cf0a67 Change heartbeat timeout (#2598) 2024-09-14 10:12:54 +02:00
Carlos Hernandez
1ef51a4ffa [client] Ensure engine is stopped before starting it back (#2565)
Before starting a new instance of the engine, check if it is nil and stop the current instance
2024-09-13 16:46:59 +02:00
Maycon Santos
f6d57e7a96 [misc] Support configurable max log size with var NB_LOG_MAX_SIZE_MB (#2592)
* Support configurable max log size with var NB_LOG_MAX_SIZE_MB

* add better logs
2024-09-12 19:56:55 +02:00
Zoltan Papp
ab892b8cf9 Fix wg handshake checking (#2590)
* Fix wg handshake checking

* Ensure in the initial handshake reading

* Change the handshake period
2024-09-12 19:18:02 +02:00
Gianluca Boiano
33c9b2d989 fix: install.sh: avoid call of netbird executable after rpm installation (#2589) 2024-09-12 17:32:47 +02:00
Bethuel Mmbaga
170e842422 [management] Add accessible peers endpoint (#2579)
* move accessible peer to separate endpoint in api doc

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add endpoint to get accessible peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update management/server/http/api/openapi.yml

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

* Update management/server/http/api/openapi.yml

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

* Update management/server/http/peers_handler.go

Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com>
2024-09-12 16:19:27 +03:00
Maycon Santos
4c130a0291 Update Go version to 1.23 (#2588) 2024-09-12 13:46:28 +02:00
Maycon Santos
afb9673bc4 [misc] Update core github actions (#2584) 2024-09-11 21:49:05 +02:00
Bethuel Mmbaga
cf6210a6f4 [management] Add GCM encryption and migrate legacy encrypted events (#2569)
* Add AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* migrate legacy encrypted data to AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor and use transaction when migrating data

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add events migration tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip migrating record on error

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preallocate capacity for nonce to avoid allocations in Seal

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-11 20:09:57 +03:00
Maycon Santos
c59a39d27d Update service package version (#2582) 2024-09-11 19:05:10 +02:00
Maycon Santos
47adb976f8 Remove pre-release step from workflow (#2583) 2024-09-11 18:59:19 +02:00
Zoltan Papp
9cfc8f8aa4 [relay] change log levels (#2580) 2024-09-11 18:36:19 +02:00
Viktor Liu
2d1bf3982d [relay] Improve relay messages (#2574)
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2024-09-11 16:20:30 +02:00
Viktor Liu
50ebbe482e [client] Don't overwrite allowed IPs when updating the wg peer's endpoint address (#2578)
This will fix broken routes on routing clients when upgrading/downgrading from/to relayed connections.
2024-09-11 16:05:13 +02:00
143 changed files with 5488 additions and 2183 deletions

View File

@@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@@ -19,13 +19,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -124,4 +124,4 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
id: go
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Download wintun
uses: carlosperate/download-file-action@v2

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
@@ -49,4 +49,4 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
version: latest
args: --timeout=12m
args: --timeout=12m

View File

@@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: run install script
env:

View File

@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v3
uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
@@ -62,4 +62,4 @@ jobs:
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
CGO_ENABLED: 0
CGO_ENABLED: 0

View File

@@ -3,15 +3,14 @@ name: Release
on:
push:
tags:
- 'v*'
- "v*"
branches:
- main
pull_request:
env:
SIGN_PIPE_VER: "v0.0.14"
GORELEASER_VER: "v1.14.1"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -34,20 +33,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
-
name: Cache Go modules
uses: actions/cache@v3
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -55,24 +51,19 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-releaser-
-
name: Install modules
- name: Install modules
run: go mod tidy
-
name: check git status
- name: check git status
run: git --no-pager diff --exit-code
-
name: Set up QEMU
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
-
name: Set up Docker Buildx
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
-
name: Login to Docker hub
- name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
with:
username: netbirdio
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
@@ -85,36 +76,32 @@ jobs:
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --rm-dist ${{ env.flags }}
args: release --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 3
-
name: upload linux packages
uses: actions/upload-artifact@v3
- name: upload linux packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
-
name: upload windows packages
uses: actions/upload-artifact@v3
- name: upload windows packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
-
name: upload macos packages
uses: actions/upload-artifact@v3
- name: upload macos packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -133,19 +120,19 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: |
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
@@ -169,14 +156,14 @@ jobs:
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: release-ui
path: dist/
@@ -187,20 +174,17 @@ jobs:
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v3
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.21"
go-version: "1.23"
cache: false
-
name: Cache Go modules
uses: actions/cache@v3
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -208,24 +192,20 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin-
-
name: Install modules
- name: Install modules
run: go mod tidy
-
name: check git status
- name: check git status
run: git --no-pager diff --exit-code
-
name: Run GoReleaser
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
@@ -233,7 +213,7 @@ jobs:
trigger_signer:
runs-on: ubuntu-latest
needs: [release,release_ui,release_ui_darwin]
needs: [release, release_ui, release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger binaries sign pipelines

View File

@@ -50,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.21.x"
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +63,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -219,10 +219,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v3
- name: handle insisting image # remove after release
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -259,9 +256,6 @@ jobs:
docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: handle insisting image gen CockroachDB # remove after release
run: docker pull netbirdio/relay:latest || docker pull netbirdio/signal:latest && docker tag netbirdio/signal:latest netbirdio/relay:latest
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:

View File

@@ -1,3 +1,5 @@
version: 2
project_name: netbird
builds:
- id: netbird
@@ -22,7 +24,7 @@ builds:
goarch: 386
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@@ -42,19 +44,19 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
- id: netbird-mgmt
dir: management
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-mgmt
goos:
- linux
@@ -64,7 +66,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-signal
dir: signal
@@ -78,7 +80,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-relay
dir: relay
@@ -92,7 +94,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
archives:
- builds:
@@ -100,7 +102,6 @@ archives:
- netbird-static
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
@@ -416,10 +417,9 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64
brews:
-
ids:
- ids:
- default
tap:
repository:
owner: netbirdio
name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
@@ -436,7 +436,7 @@ brews:
uploads:
- name: debian
ids:
- netbird-deb
- netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com

View File

@@ -1,3 +1,5 @@
version: 2
project_name: netbird-ui
builds:
- id: netbird-ui
@@ -11,7 +13,7 @@ builds:
- amd64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows
dir: client/ui
@@ -26,7 +28,7 @@ builds:
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
archives:
- id: linux-arch
@@ -39,7 +41,6 @@ archives:
- netbird-ui-windows
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
@@ -77,7 +78,7 @@ nfpms:
uploads:
- name: debian
ids:
- netbird-ui-deb
- netbird-ui-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com

View File

@@ -1,3 +1,5 @@
version: 2
project_name: netbird-ui
builds:
- id: netbird-ui-darwin
@@ -17,7 +19,7 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}'
mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@@ -28,4 +30,4 @@ archives:
checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog:
skip: true
disable: true

View File

@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser**
```shell
goreleaser --snapshot --rm-dist
goreleaser build --snapshot --clean
```
**golangci-lint**
```shell

View File

@@ -805,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}

View File

@@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
srv, err := sig.NewServer(otel.Meter(""))
srv, err := sig.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)

View File

@@ -8,8 +8,8 @@ import (
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))

View File

@@ -117,6 +117,11 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
@@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
err = WriteOutConfig(input.ConfigPath, cfg)
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
return update(input)
}

View File

@@ -158,6 +158,7 @@ func (c *ConnectClient) run(
}
defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error {
// if context cancelled we not start new backoff cycle
if c.isContextCancelled() {
@@ -267,6 +268,12 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
@@ -279,14 +286,25 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
if runningChan != nil && runningChanOpen {
runningChan <- nil
close(runningChan)
runningChanOpen = false
}
<-engineCtx.Done()
c.statusRecorder.ClientTeardown()
c.engineMutex.Lock()
engine := c.Engine()
if engine != nil {
err = engine.Stop()
}
c.engineMutex.Unlock()
if err != nil {
return err
}
backOff.Reset()
log.Info("stopped NetBird client")

View File

@@ -292,7 +292,7 @@ func (e *Engine) Start() error {
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
@@ -1115,10 +1115,7 @@ func (e *Engine) close() {
}
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
@@ -1360,12 +1357,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
}
func (e *Engine) restartEngine() {
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated")
e.clientCancel()
}
func (e *Engine) startNetworkMonitor() {
@@ -1387,6 +1388,7 @@ func (e *Engine) startNetworkMonitor() {
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
}
@@ -1396,7 +1398,7 @@ func (e *Engine) startNetworkMonitor() {
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
@@ -1421,6 +1423,20 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
return false, netip.Prefix{}, nil
}
func (e *Engine) stopDNSServer() {
err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates {
nsGroupStates[i].Enabled = false
nsGroupStates[i].Error = err
}
e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {

View File

@@ -1056,7 +1056,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(otel.Meter(""))
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)

View File

@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}()
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
@@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing routing message: %v", err)
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}

View File

@@ -89,8 +89,8 @@ type Conn struct {
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string)
statusRelay ConnStatus
statusICE ConnStatus
statusRelay *AtomicConnStatus
statusICE *AtomicConnStatus
currentConnPriority ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection
@@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(),
statusRelay: StatusDisconnected,
statusICE: StatusDisconnected,
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
}
@@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
}
} else {
if conn.statusICE == StatusDisconnected {
if conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
}
}
@@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready")
conn.statusICE = StatusConnected
conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
@@ -492,8 +492,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.currentConnPriority = connPriorityRelay
}
changed := conn.statusICE != newState && newState != StatusConnecting
conn.statusICE = newState
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
select {
case conn.iCEDisconnected <- changed:
@@ -518,18 +518,22 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return
}
conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay = StatusConnected
conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
wgProxy := conn.wgProxyFactory.GetProxy()
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr
@@ -538,7 +542,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
if conn.currentConnPriority > connPriorityRelay {
if conn.statusICE == StatusConnected {
if conn.statusICE.Get() == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
return
}
@@ -559,8 +563,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
return
}
wgConfigWorkaround()
conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround()
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
@@ -594,8 +598,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
conn.wgProxyRelay = nil
}
changed := conn.statusRelay != StatusDisconnected
conn.statusRelay = StatusDisconnected
changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
select {
case conn.relayDisconnected <- changed:
@@ -661,8 +665,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
}
func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay = StatusDisconnected
conn.statusICE = StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
conn.statusICE.Set(StatusDisconnected)
peerState := State{
PubKey: conn.config.Key,
@@ -706,7 +710,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
}
func (conn *Conn) isRelayed() bool {
if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
return false
}
@@ -718,11 +722,11 @@ func (conn *Conn) isRelayed() bool {
}
func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
return StatusConnected
}
if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
return StatusConnecting
}
@@ -733,12 +737,12 @@ func (conn *Conn) isConnected() bool {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
return false
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay != StatusConnected {
if conn.statusRelay.Get() != StatusConnected {
return false
}
}
@@ -771,13 +775,12 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
err = wgProxy.CloseConn()
if err != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
}
return nil, nil, err
}

View File

@@ -1,6 +1,10 @@
package peer
import log "github.com/sirupsen/logrus"
import (
"sync/atomic"
log "github.com/sirupsen/logrus"
)
const (
// StatusConnected indicate the peer is in connected state
@@ -12,7 +16,34 @@ const (
)
// ConnStatus describe the status of a peer's connection
type ConnStatus int
type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string {
switch s {

View File

@@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) {
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
conn.statusICE = table.statusIce
conn.statusRelay = table.statusRelay
si := NewAtomicConnStatus()
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")

View File

@@ -597,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -604,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -645,6 +649,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -654,6 +660,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock()
defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -684,6 +692,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
return d.nsGroupStates
}
@@ -695,18 +705,19 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
defer d.mux.Unlock()
fullStatus := FullStatus{
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
for _, status := range d.peers {
fullStatus.Peers = append(fullStatus.Peers, status)
}

View File

@@ -14,7 +14,7 @@ import (
)
var (
wgHandshakePeriod = 2 * time.Minute
wgHandshakePeriod = 3 * time.Minute
wgHandshakeOvertime = 30 * time.Second
)
@@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
}
ctx, ctxCancel := context.WithCancel(ctx)
go w.wgStateCheck(ctx)
w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel
w.wgStateCheck(ctx, ctxCancel)
}
func (w *WorkerRelay) DisableWgWatcher() {
@@ -157,37 +157,51 @@ func (w *WorkerRelay) CloseConn() {
}
}
// wgStateCheck help to check the state of the wireguard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
expected := wgHandshakeOvertime
for {
select {
case <-timer.C:
lastHandshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
continue
}
w.log.Tracef("last handshake: %v", lastHandshake)
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
w.log.Debugf("WireGuard watcher started")
lastHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read wg stats: %v", err)
lastHandshake = time.Time{}
}
if time.Since(lastHandshake) > expected {
w.log.Infof("Wireguard handshake timed out, closing relay connection")
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
go func(lastHandshake time.Time) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
for {
select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgHandshakeOvertime)
continue
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
}
resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
timer.Reset(resetTime)
expected = wgHandshakePeriod
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
}
}
}(lastHandshake)
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {

View File

@@ -1,4 +1,4 @@
package wgproxy
package ebpf
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package wgproxy
package ebpf
import (
"fmt"

View File

@@ -1,6 +1,6 @@
//go:build linux && !android
package wgproxy
package ebpf
import (
"context"
@@ -13,47 +13,49 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
loopbackAddr = "127.0.0.1"
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int
ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex
rawConn net.PacketConn
conn transport.UDPConn
lastUsedPort uint16
rawConn net.PacketConn
conn transport.UDPConn
ctx context.Context
ctxCancel context.CancelFunc
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
ebpfManager: ebpf.GetEbpfManagerInstance(),
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy
}
// listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) listen() error {
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
@@ -72,13 +74,14 @@ func (p *WGEBPFProxy) listen() error {
addr := net.UDPAddr{
Port: wgPorxyPort,
IP: net.ParseIP("127.0.0.1"),
IP: net.ParseIP(loopbackAddr),
}
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil {
cErr := p.Free()
if cErr != nil {
if cErr := p.Free(); cErr != nil {
log.Errorf("Failed to close the wgproxy: %s", cErr)
}
return err
@@ -91,108 +94,114 @@ func (p *WGEBPFProxy) listen() error {
}
// AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
go p.proxyToLocal(wgEndpointPort, turnConn)
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
IP: net.ParseIP(loopbackAddr),
Port: int(wgEndpointPort),
}
return wgEndpoint, nil
}
// CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error {
return nil
}
// Free resources
// Free resources except the remoteConns will be keep open.
func (p *WGEBPFProxy) Free() error {
log.Debugf("free up ebpf wg proxy")
var err1, err2, err3 error
if p.conn != nil {
err1 = p.conn.Close()
if p.ctx != nil && p.ctx.Err() != nil {
//nolint
return nil
}
err2 = p.ebpfManager.FreeWGProxy()
if p.rawConn != nil {
err3 = p.rawConn.Close()
p.ctxCancel()
var result *multierror.Error
if p.conn != nil { // p.conn will be nil if we have failed to listen
if err := p.conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if err1 != nil {
return err1
if err := p.ebpfManager.FreeWGProxy(); err != nil {
result = multierror.Append(result, err)
}
if err2 != nil {
return err2
if err := p.rawConn.Close(); err != nil {
result = multierror.Append(result, err)
}
return err3
return nberrors.FormatErrorOrNil(result)
}
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
var err error
defer func() {
p.removeTurnConn(endpointPort)
}()
for {
select {
case <-p.ctx.Done():
return
default:
var n int
n, err = remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
for p.ctx.Err() == nil {
if err := p.readAndForwardPacket(buf); err != nil {
if p.ctx.Err() != nil {
return
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
continue
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
log.Errorf("failed to proxy packet to remote conn: %s", err)
}
}
}
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
if p.ctx.Err() == nil {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
}
return nil
}
if _, err := conn.Write(buf[:n]); err != nil {
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
}
return nil
}
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
@@ -206,11 +215,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
}
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID)
_, ok := p.turnConnStore[turnConnID]
if ok {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
}
delete(p.turnConnStore, turnConnID)
}
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {

View File

@@ -1,14 +1,13 @@
//go:build linux && !android
package wgproxy
package ebpf
import (
"context"
"testing"
)
func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1)
wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)

View File

@@ -0,0 +1,44 @@
//go:build linux && !android
package ebpf
import (
"context"
"fmt"
"net"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
}
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
ctxConn, cancel := context.WithCancel(ctx)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
if err != nil {
cancel()
return nil, fmt.Errorf("add turn conn: %w", err)
}
e.remoteConn = remoteConn
e.cancel = cancel
return addr, err
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
if err := e.remoteConn.Close(); err != nil {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}

View File

@@ -1,22 +0,0 @@
package wgproxy
import "context"
type Factory struct {
wgPort int
ebpfProxy Proxy
}
func (w *Factory) GetProxy(ctx context.Context) Proxy {
if w.ebpfProxy != nil {
return w.ebpfProxy
}
return NewWGUserSpaceProxy(ctx, w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy != nil {
return w.ebpfProxy.Free()
}
return nil
}

View File

@@ -3,20 +3,26 @@
package wgproxy
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
)
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
type Factory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewFactory(userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
if userspace {
return f
}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen()
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
err := ebpfProxy.Listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
@@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f.ebpfProxy = ebpfProxy
return f
}
func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil {
p := &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
return p
}
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@@ -2,8 +2,20 @@
package wgproxy
import "context"
import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
type Factory struct {
wgPort int
}
func NewFactory(_ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}
func (w *Factory) GetProxy() Proxy {
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
return nil
}

View File

@@ -1,12 +1,12 @@
package wgproxy
import (
"context"
"net"
)
// Proxy is a transfer layer between the Turn connection and the WireGuard
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(turnConn net.Conn) (net.Addr, error)
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
CloseConn() error
Free() error
}

View File

@@ -0,0 +1,128 @@
//go:build linux
package wgproxy
import (
"context"
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
"github.com/netbirdio/netbird/util"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
type mocConn struct {
closeChan chan struct{}
closed bool
}
func newMockConn() *mocConn {
return &mocConn{
closeChan: make(chan struct{}),
}
}
func (m *mocConn) Read(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Write(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Close() error {
if m.closed == true {
return nil
}
m.closed = true
close(m.closeChan)
return nil
}
func (m *mocConn) LocalAddr() net.Addr {
panic("implement me")
}
func (m *mocConn) RemoteAddr() net.Addr {
return &net.UDPAddr{
IP: net.ParseIP("172.16.254.1"),
}
}
func (m *mocConn) SetDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetReadDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetWriteDeadline(t time.Time) error {
panic("implement me")
}
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: usp.NewWGUserSpaceProxy(51830),
},
}
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@@ -1,120 +0,0 @@
package wgproxy
import (
"context"
"fmt"
"io"
"net"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
p.ctx, p.cancel = context.WithCancel(ctx)
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn
var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
p.cancel()
if p.localConn == nil {
return nil
}
return p.localConn.Close()
}
// Free doing nothing because this implementation of proxy does not have global state
func (p *WGUserSpaceProxy) Free() error {
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
log.Debugf("failed to read from wg interface conn: %s", err)
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if err == io.EOF {
p.cancel()
} else {
log.Debugf("failed to write to remote conn: %s", err)
}
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
if err == io.EOF {
p.cancel()
return
}
log.Errorf("failed to read from remote conn: %s", err)
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}
}

View File

@@ -0,0 +1,146 @@
package usp
import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
closeMu sync.Mutex
closed bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
p.ctx, p.cancel = context.WithCancel(ctx)
p.remoteConn = remoteConn
var err error
dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *WGUserSpaceProxy) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
// prevent double close
if p.closed {
return nil
}
p.closed = true
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
return errors.FormatErrorOrNil(result)
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
return
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to remote conn: %s", err)
return
}
}
}
// proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.remoteConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}

View File

@@ -160,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(otel.Meter(""))
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)

8
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird
go 1.21.0
go 1.23.0
require (
cunicu.li/go-rosenpass v0.4.0
@@ -59,8 +59,8 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -232,7 +232,7 @@ require (
k8s.io/apimachinery v0.26.2 // indirect
)
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240904111318-17777758453a
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949

12
go.sum
View File

@@ -521,12 +521,12 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a h1:2EcDFDT39Odz5EC38pOSyjCd3bLUjPi7pMQpH6k+zzk=
github.com/netbirdio/service v0.0.0-20240904111318-17777758453a/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=

View File

@@ -56,8 +56,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,

View File

@@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,

View File

@@ -54,7 +54,7 @@ func Execute() error {
func init() {
stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")

View File

@@ -20,11 +20,6 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
@@ -41,6 +36,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
const (
@@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface {
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
GetAccount(ctx context.Context, accountID string) (*Account, error)
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
@@ -75,12 +75,14 @@ type AccountManager interface {
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -107,7 +109,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@@ -145,6 +147,7 @@ type AccountManager interface {
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
}
type DefaultAccountManager struct {
@@ -263,6 +266,16 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
// Subclass used in gorm to only load network and not whole account
type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
}
// AccountDNSSettings used in gorm to only load dns settings and not whole account
type AccountDNSSettings struct {
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
}
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
@@ -700,14 +713,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}
func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
@@ -734,14 +739,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList
}
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
@@ -1263,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil
}
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
// domain is used to create a new account if no account is found
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) {
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
// If an accountID is provided, it checks if the account exists and returns it.
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
// If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" {
return am.Store.GetAccount(ctx, accountID)
} else if userID != "" {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
if err != nil {
return nil, err
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
return account, nil
return accountID, nil
}
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
if userID != "" {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err
}
return account.Id, nil
}
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
}
func isNil(i idp.Manager) bool {
@@ -1624,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
}
// redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
// only possible with the enabled IdP manager
if am.idpManager == nil {
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
return nil
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := am.lookupUserInCache(ctx, userID, account)
if err != nil {
return err
@@ -1689,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
return am.Store.SaveAccount(ctx, account)
}
// GetAccount returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
if len(token) != PATLength {
@@ -1737,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
return account, user, pat, nil
}
// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountIDFromToken returns an account ID associated with this token.
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" {
return nil, nil, fmt.Errorf("user ID is empty")
return "", "", fmt.Errorf("user ID is empty")
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
@@ -1750,110 +1783,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
if err != nil {
return nil, nil, err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
unlock()
}
}()
account, err := am.Store.GetAccount(ctx, newAcc.Id)
if err != nil {
return nil, nil, err
return "", "", err
}
user := account.Users[claims.UserId]
if user == nil {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
// this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, account, claims.UserId)
err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
return nil, nil, err
return "", "", err
}
}
if account.Settings.JWTGroupsEnabled {
if account.Settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
return account, user, nil
}
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if slice, ok := claim.([]interface{}); ok {
var groupsNames []string
for _, item := range slice {
if g, ok := item.(string); ok {
groupsNames = append(groupsNames, g)
} else {
log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
}
}
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// if groups were added or modified, save the account
if account.SetJWTGroups(claims.UserId, groupsNames) {
if account.Settings.GroupsPropagationEnabled {
if user, err := account.FindUser(claims.UserId); err == nil {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
} else {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
unlock()
alreadyUnlocked = true
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
}
} else {
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
}
}
}
} else {
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
}
} else {
log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
}
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
return "", "", err
}
return account, user, nil
return accountID, user.Id, nil
}
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if settings == nil || !settings.JWTGroupsEnabled {
return nil
}
if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
return nil
}
// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// Update the account if group membership changes
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
if settings.GroupsPropagationEnabled {
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
}
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
return nil
}
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
return nil
}
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
// if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain
//
@@ -1870,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
return "", fmt.Errorf("user ID is empty")
}
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" {
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
return nil, err
return "", err
}
if _, ok := accountFromID.Users[claims.UserId]; !ok {
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
if userAccountID != claims.AccountId {
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
return accountFromID, nil
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
return userAccountID, nil
}
}
@@ -1899,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
// We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return nil, err
return "", err
}
}
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err == nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
return nil, err
return "", err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
if err != nil {
return nil, err
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
return "", err
}
return account, nil
return account.Id, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
var domainAccount *Account
if domainAccountID != "" {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
return nil, err
return "", err
}
}
return am.handleNewUserAccount(ctx, domainAccount, claims)
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
if err != nil {
return "", err
}
return account.Id, nil
} else {
// other error
return nil, err
return "", err
}
}
@@ -2033,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
account, _, err := am.GetAccountFromToken(ctx, claims)
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil {
return err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := make([]string, 0)
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
}
}
}
}
if settings != nil && settings.JWTGroupsEnabled {
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
@@ -2082,7 +2124,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil {
return false, err
}
@@ -2103,6 +2145,38 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
return false, nil
}
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
return newLabel, nil
}
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
}
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
@@ -2185,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
return acc
}
// extractJWTGroups extracts the group names from a JWT token's claims.
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
userJWTGroups := make([]string, 0)
if claim, ok := claims.Raw[claimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
} else {
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
}
}
}
} else {
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
}
return userJWTGroups
}
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {

View File

@@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
assert.Equal(t, account.Id, ev.TargetID)
}
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
type test struct {
@@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
@@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount = acc
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
@@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
@@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
@@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{}
@@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
if err != nil {
t.Fatal(err)
}
if account == nil {
if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId)
return
}
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
}
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
@@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
assert.NotNil(t, account.Settings)
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings)
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
}
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
@@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
@@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})

View File

@@ -6,13 +6,14 @@ import (
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"errors"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct {
block cipher.Block
gcm cipher.AEAD
}
func GenerateKey() (string, error) {
@@ -35,14 +36,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
ec := &FieldEncrypt{
block: block,
gcm: gcm,
}
return ec, nil
}
func (ec *FieldEncrypt) Encrypt(payload string) string {
func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
@@ -50,7 +58,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
return base64.StdEncoding.EncodeToString(cipherText)
}
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
// Encrypt encrypts plaintext using AES-GCM
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
plaintext := []byte(payload)
nonceSize := ec.gcm.NonceSize()
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
if _, err := rand.Read(nonce); err != nil {
return "", err
}
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
@@ -65,17 +88,49 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
return string(payload), nil
}
// Decrypt decrypts ciphertext using AES-GCM
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
nonceSize := ec.gcm.NonceSize()
if len(cipherText) < nonceSize {
return "", errors.New("cipher text too short")
}
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return "", err
}
return string(plainText), nil
}
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
paddingLen := int(src[srcLen-1])
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
return nil, fmt.Errorf("padding size error")
if srcLen == 0 {
return nil, errors.New("input data is empty")
}
paddingLen := int(src[srcLen-1])
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
return nil, errors.New("invalid padding size")
}
// Verify that all padding bytes are the same
for i := 0; i < paddingLen; i++ {
if src[srcLen-1-i] != byte(paddingLen) {
return nil, errors.New("invalid padding")
}
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -1,6 +1,7 @@
package sqlite
import (
"bytes"
"testing"
)
@@ -15,7 +16,11 @@ func TestGenerateKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
@@ -30,6 +35,32 @@ func TestGenerateKey(t *testing.T) {
}
}
func TestGenerateKeyLegacy(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.LegacyEncrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.LegacyDecrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
@@ -41,7 +72,11 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.Encrypt(testData)
encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
@@ -61,3 +96,215 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}
func TestEncryptDecrypt(t *testing.T) {
// Generate a key for encryption/decryption
key, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
// Initialize the FieldEncrypt with the generated key
ec, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("Failed to create FieldEncrypt: %v", err)
}
// Test cases
testCases := []struct {
name string
input string
}{
{
name: "Empty String",
input: "",
},
{
name: "Short String",
input: "Hello",
},
{
name: "String with Spaces",
input: "Hello, World!",
},
{
name: "Long String",
input: "The quick brown fox jumps over the lazy dog.",
},
{
name: "Unicode Characters",
input: "こんにちは世界",
},
{
name: "Special Characters",
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
},
{
name: "Numeric String",
input: "1234567890",
},
{
name: "Repeated Characters",
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
},
{
name: "Multi-block String",
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
},
{
name: "Non-ASCII and ASCII Mix",
input: "Hello 世界 123",
},
}
for _, tc := range testCases {
t.Run(tc.name+" - Legacy", func(t *testing.T) {
// Legacy Encryption
encryptedLegacy := ec.LegacyEncrypt(tc.input)
if encryptedLegacy == "" {
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
}
// Legacy Decryption
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
if err != nil {
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedLegacy != tc.input {
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
}
})
t.Run(tc.name+" - New", func(t *testing.T) {
// New Encryption
encryptedNew, err := ec.Encrypt(tc.input)
if err != nil {
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
}
if encryptedNew == "" {
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
}
// New Decryption
decryptedNew, err := ec.Decrypt(encryptedNew)
if err != nil {
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedNew != tc.input {
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
}
})
}
}
func TestPKCS5UnPadding(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectError bool
}{
{
name: "Valid Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
expected: []byte("Hello, World!"),
},
{
name: "Empty Input",
input: []byte{},
expectError: true,
},
{
name: "Padding Length Zero",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
expectError: true,
},
{
name: "Padding Length Exceeds Block Size",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
expectError: true,
},
{
name: "Padding Length Exceeds Input Length",
input: []byte{5, 5, 5},
expectError: true,
},
{
name: "Invalid Padding Bytes",
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
expectError: true,
},
{
name: "Valid Single Byte Padding",
input: append([]byte("Hello, World!"), byte(1)),
expected: []byte("Hello, World!"),
},
{
name: "Invalid Mixed Padding Bytes",
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
expectError: true,
},
{
name: "Valid Full Block Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("Hello, World!"),
},
{
name: "Non-Padding Byte at End",
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
expectError: true,
},
{
name: "Valid Padding with Different Text Length",
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
expected: []byte("Test"),
},
{
name: "Padding Length Equal to Input Length",
input: bytes.Repeat([]byte{8}, 8),
expected: []byte{},
},
{
name: "Invalid Padding Length Zero (Again)",
input: append([]byte("Test"), byte(0)),
expectError: true,
},
{
name: "Padding Length Greater Than Input",
input: []byte{10},
expectError: true,
},
{
name: "Input Length Not Multiple of Block Size",
input: append([]byte("Invalid Length"), byte(1)),
expected: []byte("Invalid Length"),
},
{
name: "Valid Padding with Non-ASCII Characters",
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
expected: []byte("こんにちは"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs5UnPadding(tt.input)
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Did not expect error but got: %v", err)
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("Expected output %v, got %v", tt.expected, result)
}
}
})
}
}

View File

@@ -0,0 +1,157 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
log "github.com/sirupsen/logrus"
)
func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
if _, err := db.Exec(createTableQuery); err != nil {
return err
}
if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
return err
}
if err := updateDeletedUsersTable(ctx, db); err != nil {
return fmt.Errorf("failed to update deleted_users table: %v", err)
}
return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
}
// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
exists, err := checkColumnExists(db, "deleted_users", "name")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
}
exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
}
return nil
}
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
// legacy CBC encryption with a static IV to the new GCM encryption method.
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}
defer func() {
_ = tx.Rollback()
}()
rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
if err != nil {
return fmt.Errorf("failed to execute select query: %v", err)
}
defer rows.Close()
updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
if err != nil {
return fmt.Errorf("failed to prepare update statement: %v", err)
}
defer updateStmt.Close()
if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %v", err)
}
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
return nil
}
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
for rows.Next() {
var (
id, decryptedEmail, decryptedName string
email, name *string
)
err := rows.Scan(&id, &email, &name)
if err != nil {
return err
}
if email != nil {
decryptedEmail, err = crypt.LegacyDecrypt(*email)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt email: %w", err),
)
continue
}
}
if name != nil {
decryptedName, err = crypt.LegacyDecrypt(*name)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt name: %w", err),
)
continue
}
}
encryptedEmail, err := crypt.Encrypt(decryptedEmail)
if err != nil {
return fmt.Errorf("failed to encrypt email: %w", err)
}
encryptedName, err := crypt.Encrypt(decryptedName)
if err != nil {
return fmt.Errorf("failed to encrypt name: %w", err)
}
_, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
if err != nil {
return err
}
}
if err := rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,84 @@
package sqlite
import (
"context"
"database/sql"
"path/filepath"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/stretchr/testify/require"
)
func setupDatabase(t *testing.T) *sql.DB {
t.Helper()
dbFile := filepath.Join(t.TempDir(), eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
require.NoError(t, err, "Failed to open database")
t.Cleanup(func() {
_ = db.Close()
})
_, err = db.Exec(createTableQuery)
require.NoError(t, err, "Failed to create events table")
_, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
require.NoError(t, err, "Failed to create deleted_users table")
return db
}
func TestMigrate(t *testing.T) {
db := setupDatabase(t)
key, err := GenerateKey()
require.NoError(t, err, "Failed to generate key")
crypt, err := NewFieldEncrypt(key)
require.NoError(t, err, "Failed to initialize FieldEncrypt")
legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
legacyName := crypt.LegacyEncrypt("Test Account")
_, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
require.NoError(t, err, "Failed to insert event")
_, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
require.NoError(t, err, "Failed to insert legacy encrypted data")
colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists")
require.False(t, colExists, "enc_algo column should not exist before migration")
err = migrate(context.Background(), crypt, db)
require.NoError(t, err, "Migration failed")
colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
require.True(t, colExists, "enc_algo column should exist after migration")
var encAlgo string
err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
require.NoError(t, err, "Failed to select updated data")
require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
store, err := createStore(crypt, db)
require.NoError(t, err, "Failed to create store")
events, err := store.Get(context.Background(), "accountID", 0, 1, false)
require.NoError(t, err, "Failed to get events")
require.Len(t, events, 1, "Should have one event")
require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
require.Equal(t, "targetID", events[0].TargetID, "target id should match")
require.Equal(t, "accountID", events[0].AccountID, "account id should match")
require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
}

View File

@@ -26,7 +26,7 @@ const (
"meta TEXT," +
" target_id TEXT);"
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events
@@ -69,10 +69,12 @@ const (
and some selfhosted deployments might have duplicates already so we need to clean the table first.
*/
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com"
gcmEncAlgo = "GCM"
)
// Store is the implementation of the activity.Store interface backed by SQLite
@@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err
}
_, err = db.Exec(createTableQuery)
if err != nil {
if err = migrate(ctx, crypt, db); err != nil {
_ = db.Close()
return nil, err
return nil, fmt.Errorf("events database migration: %w", err)
}
_, err = db.Exec(creatTableDeletedUsersQuery)
if err != nil {
_ = db.Close()
return nil, err
}
err = updateDeletedUsersTable(ctx, db)
if err != nil {
_ = db.Close()
return nil, err
}
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
s := &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}
return s, nil
return createStore(crypt, db)
}
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
@@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
return event.Meta, nil
}
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
if err != nil {
return nil, err
}
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
if err != nil {
return nil, err
}
_, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
if err != nil {
return nil, err
}
@@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
return nil
}
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
log.WithContext(ctx).Debugf("check deleted_users table version")
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
// createStore initializes and returns a new Store instance with prepared SQL statements.
func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
return err
_ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
return &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}, nil
}
// checkColumnExists checks if a column exists in a specified table
func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
rows, err := db.Query(query)
if err != nil {
return false, fmt.Errorf("failed to query table info: %w", err)
}
defer rows.Close()
found := false
for rows.Next() {
var (
cid int
name string
dataType string
notNull int
dfltVal sql.NullString
pk int
)
err := rows.Scan(&cid, &name, &dataType, &notNull, &dfltVal, &pk)
var cid int
var name, ctype string
var notnull, pk int
var dfltValue sql.NullString
err = rows.Scan(&cid, &name, &ctype, &notnull, &dfltValue, &pk)
if err != nil {
return err
return false, fmt.Errorf("failed to scan row: %w", err)
}
if name == "name" {
found = true
break
if name == columnName {
return true, nil
}
}
err = rows.Err()
if err != nil {
return err
if err = rows.Err(); err != nil {
return false, err
}
if found {
return nil
}
log.WithContext(ctx).Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
return false, nil
}

View File

@@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
}
dnsSettings := account.DNSSettings.Copy()
return &dnsSettings, nil
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
}
// SaveDNSSettings validates a user role and updates the account's DNS settings

View File

@@ -7,6 +7,7 @@ import (
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
type MockStore struct {
@@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil
}
return nil, fmt.Errorf("account not found")
return nil, status.NewPeerNotFoundError(peerId)
}
type MocAccountManager struct {

View File

@@ -2,20 +2,23 @@ package server
import (
"context"
"errors"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
@@ -46,6 +49,158 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"`
}
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
return f(s)
}
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
if !ok {
return status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return err
}
account.SetupKeys[setupKeyID].UsedTimes++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
allGroup, err := account.GetGroupAll()
if err != nil || allGroup == nil {
return errors.New("all group not found")
}
allGroup.Peers = append(allGroup.Peers, peerID)
return nil
}
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountId)
if err != nil {
return err
}
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
return nil
}
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[peer.AccountID]
if !ok {
return status.NewAccountNotFoundError(peer.AccountID)
}
account.Peers[peer.ID] = peer
return s.SaveAccount(ctx, account)
}
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[accountId]
if !ok {
return status.NewAccountNotFoundError(accountId)
}
account.Network.Serial++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
if !ok {
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
setupKey, ok := account.SetupKeys[key]
if !ok {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return setupKey, nil
}
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
var takenIps []net.IP
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps, nil
}
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
existingLabels := []string{}
for _, peer := range account.Peers {
if peer.DNSLabel != "" {
existingLabels = append(existingLabels, peer.DNSLabel)
}
}
return existingLabels, nil
}
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Network, nil
}
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
@@ -422,7 +577,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
@@ -469,7 +624,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
@@ -480,10 +635,19 @@ func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, er
return nil, err
}
return account.Users[userID].Copy(), nil
user := account.Users[userID].Copy()
pat := make([]PersonalAccessToken, 0, len(user.PATs))
for _, token := range user.PATs {
if token != nil {
pat = append(pat, *token)
}
}
user.PATsG = pat
return user, nil
}
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
@@ -513,7 +677,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return account, nil
@@ -639,13 +803,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
return "", status.NewSetupKeyNotFoundError()
}
return accountID, nil
}
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -668,7 +832,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey)
}
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -758,7 +922,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
}
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
s.mux.Lock()
defer s.mux.Unlock()
@@ -777,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.T
return nil
}
func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
}
@@ -796,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
}
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return "", "", err
}
return account.Domain, account.DomainCategory, nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
_, exists := s.Accounts[id]
return exists, nil
}
func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
}
func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
}
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
}
func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}
func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
}
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
}
func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
}
func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
}
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
}
func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
}
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
}
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
}

View File

@@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string {
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
}
// GetGroup object of the peers
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
return nil
}
// GetGroup returns a specific group by groupID in an account
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
group, ok := account.Groups[groupID]
if ok {
return group, nil
}
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
}
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
return am.Store.GetAccountGroups(ctx, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
matchingGroups := make([]*nbgroup.Group, 0)
for _, group := range account.Groups {
if group.Name == groupName {
matchingGroups = append(matchingGroups, group)
}
}
if len(matchingGroups) == 0 {
return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName)
}
maxPeers := -1
var groupWithMostPeers *nbgroup.Group
for i, group := range matchingGroups {
if len(group.Peers) > maxPeers {
maxPeers = len(group.Peers)
groupWithMostPeers = matchingGroups[i]
}
}
return groupWithMostPeers, nil
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
}
// SaveGroup object of the peers
@@ -262,6 +217,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return nil
}
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}

View File

@@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
_, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}

View File

@@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(account)
resp := toAccountResponse(accountID, settings)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(updatedAccount)
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
util.WriteJSONObject(r.Context(), w, &resp)
}
// DeleteAccount is a HTTP DELETE handler to delete an account
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
@@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
func toAccountResponse(account *server.Account) *api.Account {
jwtAllowGroups := account.Settings.JWTAllowGroups
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
}
settings := api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
apiSettings := api.AccountSettings{
PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &settings.JWTGroupsEnabled,
JwtGroupsClaimName: &settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked,
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
}
if account.Settings.Extra != nil {
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
if settings.Extra != nil {
apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
}
return &api.Account{
Id: account.Id,
Settings: settings,
Id: accountID,
Settings: apiSettings,
}
}

View File

@@ -23,8 +23,11 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return account, admin, nil
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return account.Id, admin.Id, nil
},
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
return account.Settings, nil
},
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour

View File

@@ -251,7 +251,7 @@ components:
- name
- ssh_enabled
- login_expiration_enabled
PeerBase:
Peer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'
- type: object
@@ -378,25 +378,40 @@ components:
description: User ID of the user that enrolled this peer
type: string
example: google-oauth2|277474792786460067937
os:
description: Peer's operating system and version
type: string
example: linux
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:
$ref: '#/components/schemas/CityName'
geoname_id:
description: Unique identifier from the GeoNames database for a specific geographical location.
type: integer
example: 2643743
connected:
description: Peer to Management connection status
type: boolean
example: true
last_seen:
description: Last time peer connected to Netbird's management service
type: string
format: date-time
example: "2023-05-05T10:05:26.420578Z"
required:
- ip
- dns_label
- user_id
Peer:
allOf:
- $ref: '#/components/schemas/PeerBase'
- type: object
properties:
accessible_peers:
description: List of accessible peers
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
required:
- accessible_peers
- os
- country_code
- city_name
- geoname_id
- connected
- last_seen
PeerBatch:
allOf:
- $ref: '#/components/schemas/PeerBase'
- $ref: '#/components/schemas/Peer'
- type: object
properties:
accessible_peers_count:
@@ -935,7 +950,7 @@ components:
type: array
items:
type: string
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
action:
description: Action to take upon policy match
type: string
@@ -1806,6 +1821,38 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/accessible-peers:
get:
summary: List accessible Peers
description: Returns a list of peers that the specified peer can connect to within the network.
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
responses:
'200':
description: A JSON Array of Accessible Peers
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/setup-keys:
get:
summary: List all Setup Keys

View File

@@ -152,18 +152,36 @@ const (
// AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct {
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// Name Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
}
@@ -490,81 +508,6 @@ type OSVersionCheck struct {
// Peer defines model for Peer.
type Peer struct {
// AccessiblePeers List of accessible peers
AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// ConnectionIp Peer's public connection IP address
ConnectionIp string `json:"connection_ip"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
// Hostname Hostname of the machine
Hostname string `json:"hostname"`
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
// KernelVersion Peer's operating system kernel version
KernelVersion string `json:"kernel_version"`
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
// LoginExpired Indicates whether peer's login expired or not
LoginExpired bool `json:"login_expired"`
// Name Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// SerialNumber System serial number
SerialNumber string `json:"serial_number"`
// SshEnabled Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"`
// UiVersion Peer's desktop UI version
UiVersion string `json:"ui_version"`
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
// Version Peer's daemon or cli version
Version string `json:"version"`
}
// PeerBase defines model for PeerBase.
type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`

View File

@@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
// UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e)
}
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
return &EventsHandler{
accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
@@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
}
return []*activity.Event{}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
@@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...)
handler := initEventsTestData(accountID, events...)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@@ -11,9 +11,9 @@ import (
"testing"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
"test_user": user,
},
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
return server.NewAdminUser(id), nil
},
},
geolocationManager: geo,

View File

@@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r)
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
return err
}
user, err := l.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
return err
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/gorilla/mux"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
@@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse := make([]*api.Group, 0, len(groups))
for _, group := range groups {
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
}
util.WriteJSONObject(r.Context(), w, groupsResponse)
@@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
// UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return
}
eg, ok := account.Groups[groupID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
@@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
ID: groupID,
Name: req.Name,
Peers: peers,
Issued: eg.Issued,
IntegrationReference: eg.IntegrationReference,
Issued: existingGroup.Issued,
IntegrationReference: existingGroup.IntegrationReference,
}
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
}
// CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI,
}
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
}
// DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
aID := account.Id
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
@@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return
}
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
if err != nil {
_, ok := err.(*server.GroupLinkError)
if ok {
@@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
// GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
switch r.Method {
case http.MethodGet:
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
}
func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
cache := make(map[string]api.PeerMinimum)
gr := api.Group{
Id: group.ID,
@@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
for _, pid := range group.Peers {
_, ok := cache[pid]
if !ok {
peer, ok := account.Peers[pid]
peer, ok := peersMap[pid]
if !ok {
continue
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux"
"github.com/magiconair/properties/assert"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
@@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
@@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return nil
},
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" {
groups := map[string]*nbgroup.Group{
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
}
for _, group := range initGroups {
groups[group.ID] = group
}
group, ok := groups[groupID]
if !ok {
return nil, status.Errorf(status.NotFound, "not found")
}
if groupID == "id-jwt-group" {
return &nbgroup.Group{
ID: "id-jwt-group",
Name: "Default Group",
Issued: nbgroup.GroupIssuedJWT,
}, nil
}
return &nbgroup.Group{
ID: "idofthegroup",
Name: "Group",
Issued: nbgroup.GroupIssuedAPI,
}, nil
return group, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: TestPeers,
Users: map[string]*server.User{
user.Id: user,
},
Groups: map[string]*nbgroup.Group{
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
},
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
if groupName == "All" {
return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
}
return nil, fmt.Errorf("unknown group name")
},
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" {
@@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
p := initGroupTestData(group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@@ -82,7 +82,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
@@ -115,6 +115,7 @@ func (apiHandler *apiHandler) addPeersEndpoint() {
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersEndpoint() {

View File

@@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
// CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
// DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return
}
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
// GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -18,7 +18,6 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -29,14 +28,6 @@ const (
testNSGroupAccountID = "test_id"
)
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: "super",
@@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
userID := vars["userId"]
targetUserID := vars["userId"]
if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
// GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
return
}
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
// CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
return
}
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
@@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: testDomain,
AccountId: testNSGroupAccountID,
AccountId: existingAccountID,
}
}),
),

View File

@@ -7,8 +7,6 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -16,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
// PeersHandler is a handler that returns peers of the account
@@ -71,15 +70,11 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return
}
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
}
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -101,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
}
}
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
if err != nil {
util.WriteError(ctx, err, w)
return
@@ -117,13 +112,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
return
}
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
}
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -139,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -153,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodDelete:
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
h.deletePeer(r.Context(), accountID, userID, peerID, w)
return
case http.MethodPut:
h.updatePeer(r.Context(), account, user, peerID, w, r)
return
case http.MethodGet:
h.getPeer(r.Context(), account, peerID, user.Id, w)
case http.MethodGet, http.MethodPut:
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if r.Method == http.MethodGet {
h.getPeer(r.Context(), account, peerID, userID, w)
} else {
h.updatePeer(r.Context(), account, userID, peerID, w, r)
}
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@@ -168,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
// GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -188,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
for _, peer := range account.Peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -220,32 +213,93 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
}
}
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := account.FindUser(userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
// If the user is regular user and does not own the peer
// with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser {
peer, ok := account.Peers[peerID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
return
}
if peer.UserID != user.Id {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return
}
}
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers {
ap := api.AccessiblePeer{
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
}
for _, p := range netMap.OfflinePeers {
ap := api.AccessiblePeer{
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
}
return accessiblePeers
}
func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer {
return api.AccessiblePeer{
CityName: peer.Location.CityName,
Connected: peer.Status.Connected,
CountryCode: peer.Location.CountryCode,
DnsLabel: fqdn(peer, dnsDomain),
GeonameId: int(peer.Location.GeoNameID),
Id: peer.ID,
Ip: peer.IP.String(),
LastSeen: peer.Status.LastSeen,
Name: peer.Name,
Os: peer.Meta.OS,
UserId: peer.UserID,
}
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
@@ -270,7 +324,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi
return groupsInfo
}
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer {
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
@@ -296,7 +350,6 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
@@ -12,20 +13,29 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const testPeerID = "test_peer"
const noUpdateChannelTestPeerID = "no-update-channel"
type ctxKey string
const (
testPeerID = "test_peer"
noUpdateChannelTestPeerID = "no-update-channel"
adminUser = "admin_user"
regularUser = "regular_user"
serviceUser = "service_user"
userIDKey ctxKey = "user_id"
)
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{
@@ -59,22 +69,61 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*nbpeer.Peer{
peers[0].ID: peers[0],
peers[1].ID: peers[1],
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: accountID,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: accountID,
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
"test_user": user,
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
@@ -83,7 +132,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
Serial: 51,
},
}, user, nil
}
return account, nil
},
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
@@ -99,8 +150,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
userID := r.Context().Value(userIDKey).(string)
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: userID,
Domain: "hotmail.com",
AccountId: "test_id",
}
@@ -197,6 +249,8 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@@ -227,9 +281,15 @@ func TestGetPeers(t *testing.T) {
// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
assert.Equal(t, respBody[1].Connected, false)
got = respBody[0]
for _, peer := range respBody {
if peer.Id == testPeerID {
got = peer
} else {
assert.Equal(t, peer.Connected, false)
}
}
} else {
got = &api.Peer{}
err = json.Unmarshal(content, got)
@@ -251,3 +311,119 @@ func TestGetPeers(t *testing.T) {
})
}
}
func TestGetAccessiblePeers(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
Key: "key1",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer1",
LoginExpirationEnabled: false,
UserID: regularUser,
}
peer2 := &nbpeer.Peer{
ID: "peer2",
Key: "key2",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer2",
LoginExpirationEnabled: false,
UserID: adminUser,
}
peer3 := &nbpeer.Peer{
ID: "peer3",
Key: "key3",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer3",
LoginExpirationEnabled: false,
UserID: regularUser,
}
p := initTestMetaData(peer1, peer2, peer3)
tt := []struct {
name string
peerID string
callerUserID string
expectedStatus int
expectedPeers []string
}{
{
name: "non admin user can access owned peer",
peerID: "peer1",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer2", "peer3"},
},
{
name: "non admin user can't access unowned peer",
peerID: "peer2",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "admin user can access owned peer",
peerID: "peer2",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer3"},
},
{
name: "admin user can access unowned peer",
peerID: "peer3",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
{
name: "service user can access unowned peer",
peerID: "peer3",
callerUserID: serviceUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
if res.StatusCode != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
defer res.Body.Close()
var accessiblePeers []api.AccessiblePeer
err = json.Unmarshal(body, &accessiblePeers)
if err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
peerIDs := make([]string, len(accessiblePeers))
for i, peer := range accessiblePeers {
peerIDs[i] = peer.Id
}
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"strconv"
"github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
@@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policies := []*api.Policy{}
for _, policy := range accountPolicies {
resp := toPolicyResponse(account, policy)
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policies := make([]*api.Policy, 0, len(listPolicies))
for _, policy := range listPolicies {
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
@@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
// UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return
}
policyIdx := -1
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
h.savePolicy(w, r, account, user, policyID)
}
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
h.savePolicy(w, r, account, user, "")
h.savePolicy(w, r, accountID, userID, policyID)
}
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
h.savePolicy(w, r, accountID, userID, "")
}
// savePolicy handles policy creation and update
func (h *Policies) savePolicy(
w http.ResponseWriter,
r *http.Request,
account *server.Account,
user *server.User,
policyID string,
) {
func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
var req api.PutApiPoliciesPolicyIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -127,6 +122,8 @@ func (h *Policies) savePolicy(
return
}
isUpdate := policyID != ""
if policyID == "" {
policyID = xid.New().String()
}
@@ -141,8 +138,8 @@ func (h *Policies) savePolicy(
pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name,
Destinations: groupMinimumsToStrings(account, rule.Destinations),
Sources: groupMinimumsToStrings(account, rule.Sources),
Destinations: rule.Destinations,
Sources: rule.Sources,
Bidirectional: rule.Bidirectional,
}
@@ -207,15 +204,21 @@ func (h *Policies) savePolicy(
}
if req.SourcePostureChecks != nil {
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
policy.SourcePostureChecks = *req.SourcePostureChecks
}
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(account, &policy)
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(allGroups, &policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
@@ -227,12 +230,11 @@ func (h *Policies) savePolicy(
// DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
aID := account.Id
vars := mux.Vars(r)
policyID := vars["policyId"]
@@ -241,7 +243,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
return
}
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -252,40 +254,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
// GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
switch r.Method {
case http.MethodGet:
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
}
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
groupsMap := make(map[string]*nbgroup.Group)
for _, group := range groups {
groupsMap[group.ID] = group
}
cache := make(map[string]api.GroupMinimum)
ap := &api.Policy{
Id: &policy.ID,
@@ -306,16 +314,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
Protocol: api.PolicyRuleProtocol(r.Protocol),
Action: api.PolicyRuleAction(r.Action),
}
if len(r.Ports) != 0 {
portsCopy := r.Ports
rule.Ports = &portsCopy
}
for _, gid := range r.Sources {
_, ok := cache[gid]
if ok {
continue
}
if group, ok := account.Groups[gid]; ok {
if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
@@ -325,13 +335,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
cache[gid] = minimum
}
}
for _, gid := range r.Destinations {
cachedMinimum, ok := cache[gid]
if ok {
rule.Destinations = append(rule.Destinations, cachedMinimum)
continue
}
if group, ok := account.Groups[gid]; ok {
if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
@@ -345,28 +356,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
}
return ap
}
func groupMinimumsToStrings(account *server.Account, gm []string) []string {
result := make([]string, 0, len(gm))
for _, g := range gm {
if _, ok := account.Groups[g]; !ok {
continue
}
result = append(result, g)
}
return result
}
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}

View File

@@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
user := server.NewAdminUser(userID)
return &server.Account{
Id: claims.AccountId,
Id: accountID,
Domain: "hotmail.com",
Policies: []*server.Policy{
{ID: "id-existed"},
@@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
Users: map[string]*server.User{
"test_user": user,
},
}, user, nil
}, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
// GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
postureChecks := []*api.PostureCheck{}
for _, postureCheck := range accountPostureChecks {
postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
for _, postureCheck := range listPostureChecks {
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
}
@@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
// UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
return
}
postureChecksIdx := -1
for i, postureCheck := range account.PostureChecks {
if postureCheck.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
p.savePostureChecks(w, r, account, user, postureChecksID)
p.savePostureChecks(w, r, accountID, userID, postureChecksID)
}
// CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
p.savePostureChecks(w, r, account, user, "")
p.savePostureChecks(w, r, accountID, userID, "")
}
// GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
return
}
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
// DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
return
}
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
}
// savePostureChecks handles posture checks create and update
func (p *PostureChecksHandler) savePostureChecks(
w http.ResponseWriter,
r *http.Request,
account *server.Account,
user *server.User,
postureChecksID string,
) {
func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
var (
err error
req api.PostureCheckUpdate
@@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
return
}
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
"test_user": user,
},
PostureChecks: postureChecks,
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
},
geolocationManager: &geolocation.Geolocation{},

View File

@@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
// GetAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
// CreateRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
peerGroupIds = *req.PeerGroups
}
// Do not allow non-Linux peers
if peer := account.GetPeer(peerId); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
return
}
}
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
// UpdateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
peerID = *req.Peer
}
// do not allow non Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
return
}
}
newRoute := &route.Route{
ID: route.ID(routeID),
NetID: route.NetID(req.NetworkId),
@@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups
}
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
// DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
// GetRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return
}
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
return

View File

@@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
}
if peerID != "" {
if peerID == nonLinuxExistingPeerID {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
return &route.Route{
ID: existingRouteID,
NetID: netID,
@@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
}
if r.Peer == nonLinuxExistingPeerID {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
return nil
},
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
@@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
}
return nil
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
//return testingAccount, testingAccount.Users["test_user"], nil
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
// CreateSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil {
ephemeral = *req.Ephemeral
}
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userID, ephemeral)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
// GetSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
// GetAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler {
return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
Groups: map[string]*nbgroup.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"},
},
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool,

View File

@@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
userID := vars["userId"]
if len(userID) == 0 {
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
existingUser, ok := account.Users[userID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return
}
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
Id: userID,
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked,
@@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler {
return &UsersHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return usersTestAccount.Id, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
return usersTestAccount.Users[id], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)

View File

@@ -2,10 +2,12 @@ package idp
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strings"
"sync"
"time"
@@ -97,6 +99,42 @@ type zitadelUserResponse struct {
PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"`
}
// readZitadelError parses errors returned by the zitadel APIs from a response.
func readZitadelError(body io.ReadCloser) error {
bodyBytes, err := io.ReadAll(body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
helper := JsonParser{}
var target map[string]interface{}
err = helper.Unmarshal(bodyBytes, &target)
if err != nil {
return fmt.Errorf("error unparsable body: %s", string(bodyBytes))
}
// ensure keys are ordered for consistent logging behaviour.
errorKeys := make([]string, 0, len(target))
for k := range target {
errorKeys = append(errorKeys, k)
}
slices.Sort(errorKeys)
var errsOut []string
for _, k := range errorKeys {
if _, isEmbedded := target[k].(map[string]interface{}); isEmbedded {
continue
}
errsOut = append(errsOut, fmt.Sprintf("%s: %v", k, target[k]))
}
if len(errsOut) == 0 {
return errors.New("unknown error")
}
return errors.New(strings.Join(errsOut, " "))
}
// NewZitadelManager creates a new instance of the ZitadelManager.
func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
@@ -176,7 +214,8 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
zErr := readZitadelError(resp.Body)
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d, zitadel: %w", resp.StatusCode, zErr)
}
return resp, nil
@@ -489,7 +528,9 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
zErr := readZitadelError(resp.Body)
return nil, fmt.Errorf("unable to post %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr)
}
return io.ReadAll(resp.Body)
@@ -561,7 +602,9 @@ func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
zErr := readZitadelError(resp.Body)
return nil, fmt.Errorf("unable to get %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr)
}
return io.ReadAll(resp.Body)

View File

@@ -66,7 +66,6 @@ func TestNewZitadelManager(t *testing.T) {
}
func TestZitadelRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct {
name string
inputCode int
@@ -88,15 +87,14 @@ func TestZitadelRequestJWTToken(t *testing.T) {
requestJWTTokenTestCase2 := requestJWTTokenTest{
name: "Request Bad Status Code",
inputCode: 400,
inputRespBody: "{}",
inputRespBody: "{\"error\": \"invalid_scope\", \"error_description\":\"openid missing\"}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: error: invalid_scope error_description: openid missing"),
expectedToken: "",
}
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputRespBody,
code: testCase.inputCode,
@@ -156,7 +154,7 @@ func TestZitadelParseRequestJWTResponse(t *testing.T) {
}
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
name: "Parse Bad json JWT Body",
inputRespBody: "",
inputRespBody: "{}",
helper: JsonParser{},
expectedToken: "",
expectedExpiresIn: 0,
@@ -254,7 +252,7 @@ func TestZitadelAuthenticate(t *testing.T) {
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: unknown error"),
expectedCode: 200,
expectedToken: "",
}

View File

@@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
@@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
}
func Test_LoginPerformance(t *testing.T) {
if os.Getenv("CI") == "true" {
t.Skip("Skipping on CI")
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("Skipping test on CI or Windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
@@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
// {"M", 250, 1},
// {"L", 500, 1},
// {"XL", 750, 1},
{"XXL", 2000, 1},
{"XXL", 5000, 1},
}
log.SetOutput(io.Discard)
@@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) {
}
defer mgmtServer.GracefulStop()
t.Logf("management setup complete, start registering peers")
var counter int32
var counterStart int32
var wg sync.WaitGroup
var wgAccount sync.WaitGroup
var mu sync.Mutex
messageCalls := []func() error{}
for j := 0; j < bc.accounts; j++ {
wg.Add(1)
wgAccount.Add(1)
var wgPeer sync.WaitGroup
go func(j int, counter *int32, counterStart *int32) {
defer wg.Done()
defer wgAccount.Done()
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
if err != nil {
@@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) {
return
}
startTime := time.Now()
for i := 0; i < bc.peers; i++ {
wgPeer.Add(1)
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Logf("failed to generate key: %v", err)
@@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) {
mu.Lock()
messageCalls = append(messageCalls, login)
mu.Unlock()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
go func(peerLogin PeerLogin, counterStart *int32) {
defer wgPeer.Done()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
}(peerLogin, counterStart)
}
wgPeer.Wait()
t.Logf("Time for registration: %s", time.Since(startTime))
}(j, &counter, &counterStart)
}
wg.Wait()
wgAccount.Wait()
t.Logf("prepared %d login calls", len(messageCalls))
testLoginPerformance(t, messageCalls)

View File

@@ -23,10 +23,11 @@ import (
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -48,7 +49,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
@@ -79,7 +80,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string
@@ -105,6 +106,9 @@ type MockAccountManager struct {
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
@@ -190,16 +194,14 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountID(
ctx context.Context, userId, accountId, domain string,
) (*server.Account, error) {
if am.GetAccountByUserOrAccountIdFunc != nil {
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
if am.GetAccountIDByUserOrAccountIdFunc != nil {
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
}
return nil, status.Errorf(
return "", status.Errorf(
codes.Unimplemented,
"method GetAccountByUserOrAccountID is not implemented",
"method GetAccountIDByUserOrAccountID is not implemented",
)
}
@@ -377,9 +379,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
}
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(ctx, accountID, userID, policy)
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
}
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
}
@@ -601,14 +603,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
error,
) {
if am.GetAccountFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(ctx, claims)
// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if am.GetAccountIDFromTokenFunc != nil {
return am.GetAccountIDFromTokenFunc(ctx, claims)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
}
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
@@ -802,3 +802,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
}
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
}
// GetAccountByID mocks GetAccountByID of the AccountManager interface
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
if am.GetAccountByIDFunc != nil {
return am.GetAccountByIDFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
}
// GetUserByID mocks GetUserByID of the AccountManager interface
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
if am.GetUserByIDFunc != nil {
return am.GetUserByIDFunc(ctx, id)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
}
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
if am.GetAccountSettingsFunc != nil {
return am.GetAccountSettingsFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
}
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
if am.GetAccountFunc != nil {
return am.GetAccountFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
}

View File

@@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
}
if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups")
}
nsGroup, found := account.NameServerGroups[nsGroupID]
if found {
return nsGroup.Copy(), nil
}
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
}
// CreateNameServerGroup creates and saves a new nameserver group
@@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
}
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
for _, item := range account.NameServerGroups {
nsGroups = append(nsGroups, item.Copy())
}
return nsGroups, nil
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {

View File

@@ -11,6 +11,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto"
@@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}()
var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(ctx, userID, account)
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
_, err = account.FindPeerByPubKey(peer.Key)
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
Timestamp: time.Now().UTC(),
AccountID: account.Id,
AccountID: accountID,
}
var ephemeral bool
setupKeyName := ""
if !addedByUser {
// validate the setup key if adding with a key
sk, err := account.FindSetupKey(upperKey)
if err != nil {
return nil, nil, nil, err
}
var newPeer *nbpeer.Peer
if !sk.IsValid() {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
account.SetupKeys[sk.Key] = sk.IncrementUsage()
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
ephemeral = sk.Ephemeral
setupKeyName = sk.Name
} else {
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
}
takenIps := account.getTakenIPs()
existingLabels := account.getPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
if err != nil {
return nil, nil, nil, err
}
peer.DNSLabel = newLabel
network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, nil, nil, err
}
registrationTime := time.Now().UTC()
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
Key: peer.Key,
SetupKey: upperKey,
IP: nextIp,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: newLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var groupsToAdd []string
var setupKeyID string
var setupKeyName string
var ephemeral bool
if addedByUser {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
if err != nil {
return fmt.Errorf("failed to get user groups: %w", err)
}
groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
// Validate the setup key
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
// add peer to 'All' group
group, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, err
}
group.Peers = append(group.Peers, newPeer.ID)
if !sk.IsValid() {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
var groupsToAdd []string
if addedByUser {
groupsToAdd, err = account.getUserGroups(userID)
if err != nil {
return nil, nil, nil, err
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
groupsToAdd = sk.AutoGroups
ephemeral = sk.Ephemeral
setupKeyID = sk.Id
setupKeyName = sk.Name
}
} else {
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
if err != nil {
return nil, nil, nil, err
}
}
if len(groupsToAdd) > 0 {
for _, s := range groupsToAdd {
if g, ok := account.Groups[s]; ok && g.Name != "All" {
g.Peers = append(g.Peers, newPeer.ID)
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
if addedByUser {
user, err := account.FindUser(userID)
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
return fmt.Errorf("failed to get free DNS label: %w", err)
}
user.updateLastLogin(newPeer.LastLogin)
}
account.Peers[newPeer.ID] = newPeer
account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account)
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
if err != nil {
return fmt.Errorf("failed to get free IP: %w", err)
}
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
if err != nil {
return fmt.Errorf("failed to update user last login: %w", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
}
// Account is saved, we can release the lock
unlock()
unlock = nil
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
unlock()
unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}
postureChecks := am.getPeerPostureChecks(account, peer)
postureChecks := am.getPeerPostureChecks(account, newPeer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
}
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
}
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
return nextIp, nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
@@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
if err != nil {
return err
}
@@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
return err
}
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}
@@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait()
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}

View File

@@ -7,20 +7,24 @@ import (
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route"
)
@@ -247,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -295,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
policy.Enabled = false
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) {
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}
func Test_RegisterPeerByUser(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
UserID: existingUserID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
LastLogin: time.Now(),
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
}
func Test_RegisterPeerBySetupKey(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.Error(t, err)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.NotContains(t, account.Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
assert.Equal(t, uint64(0), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
_ "embed"
"slices"
"strconv"
"strings"
@@ -314,34 +315,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
for _, policy := range account.Policies {
if policy.ID == policyID {
return policy, nil
}
}
return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
}
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -350,7 +337,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
exists := am.savePolicy(account, policy)
if err = am.savePolicy(account, policy, isUpdate); err != nil {
return err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
@@ -358,7 +347,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
action := activity.PolicyAdded
if exists {
if isUpdate {
action = activity.PolicyUpdated
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
@@ -397,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
}
return account.Policies, nil
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
@@ -434,18 +415,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
return policy, nil
}
func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) {
for i, p := range account.Policies {
if p.ID == policy.ID {
account.Policies[i] = policy
exists = true
break
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
}
// Update the existing policy
account.Policies[policyIdx] = policyToSave
return nil
}
if !exists {
account.Policies = append(account.Policies, policy)
}
return
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
return nil
}
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
@@ -560,3 +557,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
}
return nil
}
// filterValidPostureChecks filters and returns the posture check IDs from the given list
// that are valid within the provided account.
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
if _, exists := account.Groups[groupID]; exists {
result = append(result, groupID)
}
}
return result
}

View File

@@ -15,30 +15,16 @@ const (
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() {
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
for _, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
return postureChecks, nil
}
}
return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
@@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
}
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() {
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
return account.PostureChecks, nil
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {

View File

@@ -17,29 +17,16 @@ import (
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
wantedRoute, found := account.Routes[routeID]
if found {
return wantedRoute, nil
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
}
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@@ -134,6 +121,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err
}
// Do not allow non-Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
@@ -234,6 +228,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
@@ -311,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
routes := make([]*route.Route, 0, len(account.Routes))
for _, item := range account.Routes {
routes = append(routes, item)
}
return routes, nil
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
}
func toProtocolRoute(route *route.Route) *proto.Route {

View File

@@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)

View File

@@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
}
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
}
keys := make([]*SetupKey, 0, len(account.SetupKeys))
for _, key := range account.SetupKeys {
keys := make([]*SetupKey, 0, len(setupKeys))
for _, key := range setupKeys {
var k *SetupKey
if !(user.HasAdminPower() || user.IsServiceUser) {
if !user.IsAdminOrServiceUser() {
k = key.HiddenCopy(999)
} else {
k = key.Copy()
@@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser {
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
}
var foundKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyID {
foundKey = key.Copy()
break
}
}
if foundKey == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
if foundKey.UpdatedAt.IsZero() {
foundKey.UpdatedAt = foundKey.CreatedAt
if setupKey.UpdatedAt.IsZero() {
setupKey.UpdatedAt = setupKey.CreatedAt
}
if !(user.HasAdminPower() || user.IsServiceUser) {
foundKey = foundKey.HiddenCopy(999)
if !user.IsAdminOrServiceUser() {
setupKey = setupKey.HiddenCopy(999)
}
return foundKey, nil
return setupKey, nil
}
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
@@ -33,7 +34,9 @@ import (
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
)
@@ -397,31 +400,40 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
var account Account
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if err != nil {
return nil, err
}
// TODO: rework to not call GetAccount
return s.GetAccount(ctx, account.Id)
return s.GetAccount(ctx, accountID)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory,
).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
return nil, status.NewSetupKeyNotFoundError()
}
if key.AccountID == "" {
@@ -474,15 +486,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.First(&user, idQueryCondition, userID)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
return nil, status.NewUserNotFoundError(userID)
}
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting user from store")
return nil, status.NewGetUserFromStoreError()
}
return &user, nil
@@ -490,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, e
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, idQueryCondition, accountID)
result := s.db.Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -535,7 +547,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -595,7 +607,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -612,12 +624,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -631,12 +642,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -650,12 +660,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
@@ -663,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var user User
var accountID string
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -677,61 +685,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var key SetupKey
var accountID string
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting setup key from store")
return "", status.NewSetupKeyNotFoundError()
}
if accountID == "" {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return accountID, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
result := s.db.First(&peer, "key = ?", peerKey)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
return status.NewUserNotFoundError(userID)
}
return status.Errorf(status.Internal, "issue getting user from store")
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
return s.db.Save(user).Error
return s.db.Save(&user).Error
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@@ -850,3 +914,276 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
"last_used": time.Now(),
})
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
}
return nil
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
}
return nil
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
}
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
}
return nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
}
}
func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
var accountDNSSettings AccountDNSSettings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found")
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
}
return &accountDNSSettings.DNSSettings, nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("id").First(&accountID, idQueryCondition, id)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return false, nil
}
return false, result.Error
}
return accountID != "", nil
}
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", "", status.Errorf(status.NotFound, "account not found")
}
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
}
return account.Domain, account.DomainCategory, nil
}
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
}
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
}
return &group, nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
}
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
}
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
var record T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}

View File

@@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
func TestSqlite_GetTakenIPs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []net.IP{}, takenIPs)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip1 := net.IP{1, 1, 1, 1}.To16()
assert.Equal(t, []net.IP{ip1}, takenIPs)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip2 := net.IP{2, 2, 2, 2}.To16()
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
DNSLabel: "peer2.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip := net.IP{100, 64, 0, 0}.To16()
assert.Equal(t, ip, network.Net.IP)
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
assert.Equal(t, "", network.Dns)
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
assert.Equal(t, uint64(0), network.Serial)
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
assert.Equal(t, "Default key", setupKey.Name)
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 0, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 1, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}

View File

@@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError() error {
return Errorf(NotFound, "setup key not found")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
@@ -27,47 +28,94 @@ import (
"github.com/netbirdio/netbird/route"
)
type LockingStrength string
const (
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
)
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
DeleteAccount(ctx context.Context, account *Account) error
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func()
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
}
type StoreEngine string

Some files were not shown because too many files have changed in this diff Show More