mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-30 06:06:38 +00:00
Compare commits
18 Commits
debug-0.29
...
fix/engine
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
933cf1c84d | ||
|
|
5eb936b49e | ||
|
|
24c0aaa745 | ||
|
|
16179db599 | ||
|
|
e27f85b317 | ||
|
|
2fd60b2cb4 | ||
|
|
3dca6099d4 | ||
|
|
cfbcf507fb | ||
|
|
52ae693c9e | ||
|
|
58ff7ab797 | ||
|
|
acb73bd64a | ||
|
|
4ebf6e1c4c | ||
|
|
1e4a0f77e2 | ||
|
|
b51d75204b | ||
|
|
e7d52c8c95 | ||
|
|
ab82302c95 | ||
|
|
d47be154ea | ||
|
|
35c892aea3 |
74
.github/workflows/release.yml
vendored
74
.github/workflows/release.yml
vendored
@@ -3,15 +3,14 @@ name: Release
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- 'v*'
|
- "v*"
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.14"
|
SIGN_PIPE_VER: "v0.0.14"
|
||||||
GORELEASER_VER: "v1.14.1"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||||
|
|
||||||
@@ -34,19 +33,16 @@ jobs:
|
|||||||
|
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
-
|
- name: Checkout
|
||||||
name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
-
|
- name: Set up Go
|
||||||
name: Set up Go
|
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version: "1.23"
|
||||||
cache: false
|
cache: false
|
||||||
-
|
- name: Cache Go modules
|
||||||
name: Cache Go modules
|
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -55,24 +51,19 @@ jobs:
|
|||||||
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-go-releaser-
|
${{ runner.os }}-go-releaser-
|
||||||
-
|
- name: Install modules
|
||||||
name: Install modules
|
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
-
|
- name: check git status
|
||||||
name: check git status
|
|
||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
-
|
- name: Set up QEMU
|
||||||
name: Set up QEMU
|
|
||||||
uses: docker/setup-qemu-action@v2
|
uses: docker/setup-qemu-action@v2
|
||||||
-
|
- name: Set up Docker Buildx
|
||||||
name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v2
|
||||||
-
|
- name: Login to Docker hub
|
||||||
name: Login to Docker hub
|
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v1
|
||||||
with:
|
with:
|
||||||
username: netbirdio
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
- name: Install OS build dependencies
|
- name: Install OS build dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||||
@@ -85,35 +76,31 @@ jobs:
|
|||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --rm-dist ${{ env.flags }}
|
args: release --clean ${{ env.flags }}
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
-
|
- name: upload non tags for debug purposes
|
||||||
name: upload non tags for debug purposes
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 3
|
retention-days: 3
|
||||||
-
|
- name: upload linux packages
|
||||||
name: upload linux packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-packages
|
name: linux-packages
|
||||||
path: dist/netbird_linux**
|
path: dist/netbird_linux**
|
||||||
retention-days: 3
|
retention-days: 3
|
||||||
-
|
- name: upload windows packages
|
||||||
name: upload windows packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-packages
|
name: windows-packages
|
||||||
path: dist/netbird_windows**
|
path: dist/netbird_windows**
|
||||||
retention-days: 3
|
retention-days: 3
|
||||||
-
|
- name: upload macos packages
|
||||||
name: upload macos packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: macos-packages
|
name: macos-packages
|
||||||
@@ -145,7 +132,7 @@ jobs:
|
|||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
~/.cache/go-build
|
~/.cache/go-build
|
||||||
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -169,7 +156,7 @@ jobs:
|
|||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
|
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
@@ -187,19 +174,16 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
-
|
- name: Checkout
|
||||||
name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
-
|
- name: Set up Go
|
||||||
name: Set up Go
|
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version: "1.23"
|
||||||
cache: false
|
cache: false
|
||||||
-
|
- name: Cache Go modules
|
||||||
name: Cache Go modules
|
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -208,23 +192,19 @@ jobs:
|
|||||||
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-ui-go-releaser-darwin-
|
${{ runner.os }}-ui-go-releaser-darwin-
|
||||||
-
|
- name: Install modules
|
||||||
name: Install modules
|
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
-
|
- name: check git status
|
||||||
name: check git status
|
|
||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
-
|
- name: Run GoReleaser
|
||||||
name: Run GoReleaser
|
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
|
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
-
|
- name: upload non tags for debug purposes
|
||||||
name: upload non tags for debug purposes
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release-ui-darwin
|
name: release-ui-darwin
|
||||||
@@ -233,7 +213,7 @@ jobs:
|
|||||||
|
|
||||||
trigger_signer:
|
trigger_signer:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [release,release_ui,release_ui_darwin]
|
needs: [release, release_ui, release_ui_darwin]
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger binaries sign pipelines
|
- name: Trigger binaries sign pipelines
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
version: 2
|
||||||
|
|
||||||
project_name: netbird
|
project_name: netbird
|
||||||
builds:
|
builds:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
@@ -22,7 +24,7 @@ builds:
|
|||||||
goarch: 386
|
goarch: 386
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
|
|
||||||
@@ -42,19 +44,19 @@ builds:
|
|||||||
- softfloat
|
- softfloat
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
|
|
||||||
- id: netbird-mgmt
|
- id: netbird-mgmt
|
||||||
dir: management
|
dir: management
|
||||||
env:
|
env:
|
||||||
- CGO_ENABLED=1
|
- CGO_ENABLED=1
|
||||||
- >-
|
- >-
|
||||||
{{- if eq .Runtime.Goos "linux" }}
|
{{- if eq .Runtime.Goos "linux" }}
|
||||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
binary: netbird-mgmt
|
binary: netbird-mgmt
|
||||||
goos:
|
goos:
|
||||||
- linux
|
- linux
|
||||||
@@ -64,7 +66,7 @@ builds:
|
|||||||
- arm
|
- arm
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird-signal
|
- id: netbird-signal
|
||||||
dir: signal
|
dir: signal
|
||||||
@@ -78,7 +80,7 @@ builds:
|
|||||||
- arm
|
- arm
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird-relay
|
- id: netbird-relay
|
||||||
dir: relay
|
dir: relay
|
||||||
@@ -92,7 +94,7 @@ builds:
|
|||||||
- arm
|
- arm
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
- builds:
|
- builds:
|
||||||
@@ -100,7 +102,6 @@ archives:
|
|||||||
- netbird-static
|
- netbird-static
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
|
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client.
|
description: Netbird client.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
@@ -416,10 +417,9 @@ docker_manifests:
|
|||||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
|
|
||||||
brews:
|
brews:
|
||||||
-
|
- ids:
|
||||||
ids:
|
|
||||||
- default
|
- default
|
||||||
tap:
|
repository:
|
||||||
owner: netbirdio
|
owner: netbirdio
|
||||||
name: homebrew-tap
|
name: homebrew-tap
|
||||||
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
|
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
|
||||||
@@ -436,7 +436,7 @@ brews:
|
|||||||
uploads:
|
uploads:
|
||||||
- name: debian
|
- name: debian
|
||||||
ids:
|
ids:
|
||||||
- netbird-deb
|
- netbird-deb
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
version: 2
|
||||||
|
|
||||||
project_name: netbird-ui
|
project_name: netbird-ui
|
||||||
builds:
|
builds:
|
||||||
- id: netbird-ui
|
- id: netbird-ui
|
||||||
@@ -11,7 +13,7 @@ builds:
|
|||||||
- amd64
|
- amd64
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird-ui-windows
|
- id: netbird-ui-windows
|
||||||
dir: client/ui
|
dir: client/ui
|
||||||
@@ -26,7 +28,7 @@ builds:
|
|||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
- -H windowsgui
|
- -H windowsgui
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
- id: linux-arch
|
- id: linux-arch
|
||||||
@@ -39,7 +41,6 @@ archives:
|
|||||||
- netbird-ui-windows
|
- netbird-ui-windows
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
|
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client UI.
|
description: Netbird client UI.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
@@ -77,7 +78,7 @@ nfpms:
|
|||||||
uploads:
|
uploads:
|
||||||
- name: debian
|
- name: debian
|
||||||
ids:
|
ids:
|
||||||
- netbird-ui-deb
|
- netbird-ui-deb
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
version: 2
|
||||||
|
|
||||||
project_name: netbird-ui
|
project_name: netbird-ui
|
||||||
builds:
|
builds:
|
||||||
- id: netbird-ui-darwin
|
- id: netbird-ui-darwin
|
||||||
@@ -17,7 +19,7 @@ builds:
|
|||||||
- softfloat
|
- softfloat
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
|
|
||||||
@@ -28,4 +30,4 @@ archives:
|
|||||||
checksum:
|
checksum:
|
||||||
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
|
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
|
||||||
changelog:
|
changelog:
|
||||||
skip: true
|
disable: true
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
|
|||||||
|
|
||||||
**Goreleaser**
|
**Goreleaser**
|
||||||
```shell
|
```shell
|
||||||
goreleaser --snapshot --rm-dist
|
goreleaser build --snapshot --clean
|
||||||
```
|
```
|
||||||
**golangci-lint**
|
**golangci-lint**
|
||||||
```shell
|
```shell
|
||||||
|
|||||||
@@ -805,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
|||||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||||
|
|
||||||
for i, route := range peer.Routes {
|
for i, route := range peer.Routes {
|
||||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
srv, err := sig.NewServer(otel.Meter(""))
|
srv, err := sig.NewServer(context.Background(), otel.Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sigProto.RegisterSignalExchangeServer(s, srv)
|
sigProto.RegisterSignalExchangeServer(s, srv)
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func formatError(es []error) string {
|
func formatError(es []error) string {
|
||||||
if len(es) == 0 {
|
if len(es) == 1 {
|
||||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
points := make([]string, len(es))
|
points := make([]string, len(es))
|
||||||
|
|||||||
@@ -295,6 +295,16 @@ func (c *ConnectClient) run(
|
|||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
c.statusRecorder.ClientTeardown()
|
c.statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
|
c.engineMutex.Lock()
|
||||||
|
engine := c.Engine()
|
||||||
|
if engine != nil {
|
||||||
|
err = engine.Stop()
|
||||||
|
}
|
||||||
|
c.engineMutex.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
backOff.Reset()
|
backOff.Reset()
|
||||||
|
|
||||||
log.Info("stopped NetBird client")
|
log.Info("stopped NetBird client")
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ func (e *Engine) Start() error {
|
|||||||
e.wgInterface = wgIface
|
e.wgInterface = wgIface
|
||||||
|
|
||||||
userspace := e.wgInterface.IsUserspaceBind()
|
userspace := e.wgInterface.IsUserspaceBind()
|
||||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
|
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
|
||||||
|
|
||||||
if e.config.RosenpassEnabled {
|
if e.config.RosenpassEnabled {
|
||||||
log.Infof("rosenpass is enabled")
|
log.Infof("rosenpass is enabled")
|
||||||
|
|||||||
@@ -1056,7 +1056,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
|||||||
log.Fatalf("failed to listen: %v", err)
|
log.Fatalf("failed to listen: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
proto.RegisterSignalExchangeServer(s, srv)
|
proto.RegisterSignalExchangeServer(s, srv)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
defer func() {
|
defer func() {
|
||||||
err := unix.Close(fd)
|
err := unix.Close(fd)
|
||||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||||
log.Errorf("Network monitor: failed to close routing socket: %v", err)
|
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
err := unix.Close(fd)
|
err := unix.Close(fd)
|
||||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||||
log.Debugf("Network monitor: closed routing socket")
|
log.Debugf("Network monitor: closed routing socket: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
n, err := unix.Read(fd, buf)
|
n, err := unix.Read(fd, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||||
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if n < unix.SizeofRtMsghdr {
|
if n < unix.SizeofRtMsghdr {
|
||||||
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||||
route, err := parseRouteMessage(buf[:n])
|
route, err := parseRouteMessage(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Network monitor: error parsing routing message: %v", err)
|
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -527,8 +527,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.log.Debugf("Relay connection is ready to use")
|
conn.log.Debugf("Relay connection is ready to use")
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.Set(StatusConnected)
|
||||||
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||||
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
|
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
@@ -775,8 +775,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
|
|||||||
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
||||||
}
|
}
|
||||||
conn.log.Debugf("setup ice turn connection")
|
conn.log.Debugf("setup ice turn connection")
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||||
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
|
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
if errClose := wgProxy.CloseConn(); errClose != nil {
|
if errClose := wgProxy.CloseConn(); errClose != nil {
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
@@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
@@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
@@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package wgproxy
|
package ebpf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package wgproxy
|
package ebpf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android
|
||||||
|
|
||||||
package wgproxy
|
package ebpf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -13,47 +13,49 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
loopbackAddr = "127.0.0.1"
|
||||||
|
)
|
||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
type WGEBPFProxy struct {
|
type WGEBPFProxy struct {
|
||||||
ebpfManager ebpfMgr.Manager
|
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
lastUsedPort uint16
|
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
|
|
||||||
|
ebpfManager ebpfMgr.Manager
|
||||||
turnConnStore map[uint16]net.Conn
|
turnConnStore map[uint16]net.Conn
|
||||||
turnConnMutex sync.Mutex
|
turnConnMutex sync.Mutex
|
||||||
|
|
||||||
rawConn net.PacketConn
|
lastUsedPort uint16
|
||||||
conn transport.UDPConn
|
rawConn net.PacketConn
|
||||||
|
conn transport.UDPConn
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||||
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
|
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
||||||
log.Debugf("instantiate ebpf proxy")
|
log.Debugf("instantiate ebpf proxy")
|
||||||
wgProxy := &WGEBPFProxy{
|
wgProxy := &WGEBPFProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
ebpfManager: ebpf.GetEbpfManagerInstance(),
|
ebpfManager: ebpf.GetEbpfManagerInstance(),
|
||||||
lastUsedPort: 0,
|
|
||||||
turnConnStore: make(map[uint16]net.Conn),
|
turnConnStore: make(map[uint16]net.Conn),
|
||||||
}
|
}
|
||||||
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
|
|
||||||
|
|
||||||
return wgProxy
|
return wgProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen load ebpf program and listen the proxy
|
// Listen load ebpf program and listen the proxy
|
||||||
func (p *WGEBPFProxy) listen() error {
|
func (p *WGEBPFProxy) Listen() error {
|
||||||
pl := portLookup{}
|
pl := portLookup{}
|
||||||
wgPorxyPort, err := pl.searchFreePort()
|
wgPorxyPort, err := pl.searchFreePort()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -72,13 +74,14 @@ func (p *WGEBPFProxy) listen() error {
|
|||||||
|
|
||||||
addr := net.UDPAddr{
|
addr := net.UDPAddr{
|
||||||
Port: wgPorxyPort,
|
Port: wgPorxyPort,
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
IP: net.ParseIP(loopbackAddr),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
conn, err := nbnet.ListenUDP("udp", &addr)
|
conn, err := nbnet.ListenUDP("udp", &addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cErr := p.Free()
|
if cErr := p.Free(); cErr != nil {
|
||||||
if cErr != nil {
|
|
||||||
log.Errorf("Failed to close the wgproxy: %s", cErr)
|
log.Errorf("Failed to close the wgproxy: %s", cErr)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -91,108 +94,114 @@ func (p *WGEBPFProxy) listen() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn add new turn connection for the proxy
|
// AddTurnConn add new turn connection for the proxy
|
||||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
|
||||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.proxyToLocal(wgEndpointPort, turnConn)
|
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
|
||||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||||
|
|
||||||
wgEndpoint := &net.UDPAddr{
|
wgEndpoint := &net.UDPAddr{
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
IP: net.ParseIP(loopbackAddr),
|
||||||
Port: int(wgEndpointPort),
|
Port: int(wgEndpointPort),
|
||||||
}
|
}
|
||||||
return wgEndpoint, nil
|
return wgEndpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn doing nothing because this type of proxy implementation does not store the connection
|
// Free resources except the remoteConns will be keep open.
|
||||||
func (p *WGEBPFProxy) CloseConn() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free resources
|
|
||||||
func (p *WGEBPFProxy) Free() error {
|
func (p *WGEBPFProxy) Free() error {
|
||||||
log.Debugf("free up ebpf wg proxy")
|
log.Debugf("free up ebpf wg proxy")
|
||||||
var err1, err2, err3 error
|
if p.ctx != nil && p.ctx.Err() != nil {
|
||||||
if p.conn != nil {
|
//nolint
|
||||||
err1 = p.conn.Close()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err2 = p.ebpfManager.FreeWGProxy()
|
p.ctxCancel()
|
||||||
if p.rawConn != nil {
|
|
||||||
err3 = p.rawConn.Close()
|
var result *multierror.Error
|
||||||
|
if p.conn != nil { // p.conn will be nil if we have failed to listen
|
||||||
|
if err := p.conn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err1 != nil {
|
if err := p.ebpfManager.FreeWGProxy(); err != nil {
|
||||||
return err1
|
result = multierror.Append(result, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err2 != nil {
|
if err := p.rawConn.Close(); err != nil {
|
||||||
return err2
|
result = multierror.Append(result, err)
|
||||||
}
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
return err3
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
|
||||||
|
defer p.removeTurnConn(endpointPort)
|
||||||
|
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
n int
|
||||||
|
)
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
var err error
|
for ctx.Err() == nil {
|
||||||
defer func() {
|
n, err = remoteConn.Read(buf)
|
||||||
p.removeTurnConn(endpointPort)
|
if err != nil {
|
||||||
}()
|
if ctx.Err() != nil {
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
var n int
|
|
||||||
n, err = remoteConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
if err != io.EOF {
|
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = p.sendPkg(buf[:n], endpointPort)
|
if err != io.EOF {
|
||||||
if err != nil {
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
|
||||||
|
if ctx.Err() != nil || p.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||||
|
// From this go routine has only one instance.
|
||||||
func (p *WGEBPFProxy) proxyToRemote() {
|
func (p *WGEBPFProxy) proxyToRemote() {
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for p.ctx.Err() == nil {
|
||||||
select {
|
if err := p.readAndForwardPacket(buf); err != nil {
|
||||||
case <-p.ctx.Done():
|
if p.ctx.Err() != nil {
|
||||||
return
|
|
||||||
default:
|
|
||||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to read UDP pkg from WG: %s", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.Errorf("failed to proxy packet to remote conn: %s", err)
|
||||||
p.turnConnMutex.Lock()
|
|
||||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
|
||||||
p.turnConnMutex.Unlock()
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.Write(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
|
||||||
|
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.turnConnMutex.Lock()
|
||||||
|
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||||
|
p.turnConnMutex.Unlock()
|
||||||
|
if !ok {
|
||||||
|
if p.ctx.Err() == nil {
|
||||||
|
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write(buf[:n]); err != nil {
|
||||||
|
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
||||||
p.turnConnMutex.Lock()
|
p.turnConnMutex.Lock()
|
||||||
defer p.turnConnMutex.Unlock()
|
defer p.turnConnMutex.Unlock()
|
||||||
@@ -206,11 +215,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
||||||
log.Debugf("remove turn conn from store by port: %d", turnConnID)
|
|
||||||
p.turnConnMutex.Lock()
|
p.turnConnMutex.Lock()
|
||||||
defer p.turnConnMutex.Unlock()
|
defer p.turnConnMutex.Unlock()
|
||||||
delete(p.turnConnStore, turnConnID)
|
|
||||||
|
|
||||||
|
_, ok := p.turnConnStore[turnConnID]
|
||||||
|
if ok {
|
||||||
|
log.Debugf("remove turn conn from store by port: %d", turnConnID)
|
||||||
|
}
|
||||||
|
delete(p.turnConnStore, turnConnID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
||||||
@@ -1,14 +1,13 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android
|
||||||
|
|
||||||
package wgproxy
|
package ebpf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWGEBPFProxy_connStore(t *testing.T) {
|
func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
wgProxy := NewWGEBPFProxy(1)
|
||||||
|
|
||||||
p, _ := wgProxy.storeTurnConn(nil)
|
p, _ := wgProxy.storeTurnConn(nil)
|
||||||
if p != 1 {
|
if p != 1 {
|
||||||
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
wgProxy := NewWGEBPFProxy(1)
|
||||||
|
|
||||||
_, _ = wgProxy.storeTurnConn(nil)
|
_, _ = wgProxy.storeTurnConn(nil)
|
||||||
wgProxy.lastUsedPort = 65535
|
wgProxy.lastUsedPort = 65535
|
||||||
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
||||||
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
wgProxy := NewWGEBPFProxy(1)
|
||||||
|
|
||||||
for i := 0; i < 65535; i++ {
|
for i := 0; i < 65535; i++ {
|
||||||
_, _ = wgProxy.storeTurnConn(nil)
|
_, _ = wgProxy.storeTurnConn(nil)
|
||||||
44
client/internal/wgproxy/ebpf/wrapper.go
Normal file
44
client/internal/wgproxy/ebpf/wrapper.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package ebpf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
|
type ProxyWrapper struct {
|
||||||
|
WgeBPFProxy *WGEBPFProxy
|
||||||
|
|
||||||
|
remoteConn net.Conn
|
||||||
|
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
||||||
|
ctxConn, cancel := context.WithCancel(ctx)
|
||||||
|
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
return nil, fmt.Errorf("add turn conn: %w", err)
|
||||||
|
}
|
||||||
|
e.remoteConn = remoteConn
|
||||||
|
e.cancel = cancel
|
||||||
|
return addr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||||
|
func (e *ProxyWrapper) CloseConn() error {
|
||||||
|
if e.cancel == nil {
|
||||||
|
return fmt.Errorf("proxy not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
e.cancel()
|
||||||
|
|
||||||
|
if err := e.remoteConn.Close(); err != nil {
|
||||||
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
package wgproxy
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
type Factory struct {
|
|
||||||
wgPort int
|
|
||||||
ebpfProxy Proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Factory) GetProxy(ctx context.Context) Proxy {
|
|
||||||
if w.ebpfProxy != nil {
|
|
||||||
return w.ebpfProxy
|
|
||||||
}
|
|
||||||
return NewWGUserSpaceProxy(ctx, w.wgPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Factory) Free() error {
|
|
||||||
if w.ebpfProxy != nil {
|
|
||||||
return w.ebpfProxy.Free()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -3,20 +3,26 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
|
type Factory struct {
|
||||||
|
wgPort int
|
||||||
|
ebpfProxy *ebpf.WGEBPFProxy
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFactory(userspace bool, wgPort int) *Factory {
|
||||||
f := &Factory{wgPort: wgPort}
|
f := &Factory{wgPort: wgPort}
|
||||||
|
|
||||||
if userspace {
|
if userspace {
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
|
||||||
err := ebpfProxy.listen()
|
err := ebpfProxy.Listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||||
return f
|
return f
|
||||||
@@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
|
|||||||
f.ebpfProxy = ebpfProxy
|
f.ebpfProxy = ebpfProxy
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Factory) GetProxy() Proxy {
|
||||||
|
if w.ebpfProxy != nil {
|
||||||
|
p := &ebpf.ProxyWrapper{
|
||||||
|
WgeBPFProxy: w.ebpfProxy,
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
return usp.NewWGUserSpaceProxy(w.wgPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Factory) Free() error {
|
||||||
|
if w.ebpfProxy == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.ebpfProxy.Free()
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,8 +2,20 @@
|
|||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import "context"
|
import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||||
|
|
||||||
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
|
type Factory struct {
|
||||||
|
wgPort int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFactory(_ bool, wgPort int) *Factory {
|
||||||
return &Factory{wgPort: wgPort}
|
return &Factory{wgPort: wgPort}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Factory) GetProxy() Proxy {
|
||||||
|
return usp.NewWGUserSpaceProxy(w.wgPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Factory) Free() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Proxy is a transfer layer between the Turn connection and the WireGuard
|
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
AddTurnConn(turnConn net.Conn) (net.Addr, error)
|
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
Free() error
|
|
||||||
}
|
}
|
||||||
|
|||||||
128
client/internal/wgproxy/proxy_test.go
Normal file
128
client/internal/wgproxy/proxy_test.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
_ = util.InitLog("trace", "console")
|
||||||
|
code := m.Run()
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mocConn struct {
|
||||||
|
closeChan chan struct{}
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockConn() *mocConn {
|
||||||
|
return &mocConn{
|
||||||
|
closeChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mocConn) Read(b []byte) (n int, err error) {
|
||||||
|
<-m.closeChan
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mocConn) Write(b []byte) (n int, err error) {
|
||||||
|
<-m.closeChan
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mocConn) Close() error {
|
||||||
|
if m.closed == true {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.closed = true
|
||||||
|
close(m.closeChan)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mocConn) LocalAddr() net.Addr {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mocConn) RemoteAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("172.16.254.1"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mocConn) SetDeadline(t time.Time) error {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mocConn) SetReadDeadline(t time.Time) error {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
proxy Proxy
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "userspace proxy",
|
||||||
|
proxy: usp.NewWGUserSpaceProxy(51830),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
proxyWrapper := &ebpf.ProxyWrapper{
|
||||||
|
WgeBPFProxy: ebpfProxy,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = append(tests, struct {
|
||||||
|
name string
|
||||||
|
proxy Proxy
|
||||||
|
}{
|
||||||
|
name: "ebpf proxy",
|
||||||
|
proxy: proxyWrapper,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
relayedConn := newMockConn()
|
||||||
|
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = relayedConn.Close()
|
||||||
|
if err := tt.proxy.CloseConn(); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
package wgproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WGUserSpaceProxy proxies
|
|
||||||
type WGUserSpaceProxy struct {
|
|
||||||
localWGListenPort int
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
remoteConn net.Conn
|
|
||||||
localConn net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
|
||||||
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
|
|
||||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
|
||||||
p := &WGUserSpaceProxy{
|
|
||||||
localWGListenPort: wgPort,
|
|
||||||
}
|
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
|
||||||
p.remoteConn = remoteConn
|
|
||||||
|
|
||||||
var err error
|
|
||||||
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
go p.proxyToRemote()
|
|
||||||
go p.proxyToLocal()
|
|
||||||
|
|
||||||
return p.localConn.LocalAddr(), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloseConn close the localConn
|
|
||||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
|
||||||
p.cancel()
|
|
||||||
if p.localConn == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.remoteConn == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := p.remoteConn.Close(); err != nil {
|
|
||||||
log.Warnf("failed to close remote conn: %s", err)
|
|
||||||
}
|
|
||||||
return p.localConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free doing nothing because this implementation of proxy does not have global state
|
|
||||||
func (p *WGUserSpaceProxy) Free() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
|
||||||
// blocks
|
|
||||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
|
||||||
defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
|
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
n, err := p.localConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = p.remoteConn.Write(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
p.cancel()
|
|
||||||
} else {
|
|
||||||
log.Debugf("failed to write to remote conn: %s", err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
|
||||||
// blocks
|
|
||||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
|
||||||
defer p.cancel()
|
|
||||||
defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
n, err := p.remoteConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Errorf("failed to read from remote conn: %s", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = p.localConn.Write(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
146
client/internal/wgproxy/usp/proxy.go
Normal file
146
client/internal/wgproxy/usp/proxy.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package usp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WGUserSpaceProxy proxies
|
||||||
|
type WGUserSpaceProxy struct {
|
||||||
|
localWGListenPort int
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
remoteConn net.Conn
|
||||||
|
localConn net.Conn
|
||||||
|
closeMu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||||
|
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||||
|
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||||
|
p := &WGUserSpaceProxy{
|
||||||
|
localWGListenPort: wgPort,
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTurnConn start the proxy with the given remote conn
|
||||||
|
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
||||||
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
|
var err error
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.proxyToRemote()
|
||||||
|
go p.proxyToLocal()
|
||||||
|
|
||||||
|
return p.localConn.LocalAddr(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseConn close the localConn
|
||||||
|
func (p *WGUserSpaceProxy) CloseConn() error {
|
||||||
|
if p.cancel == nil {
|
||||||
|
return fmt.Errorf("proxy not started")
|
||||||
|
}
|
||||||
|
return p.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WGUserSpaceProxy) close() error {
|
||||||
|
p.closeMu.Lock()
|
||||||
|
defer p.closeMu.Unlock()
|
||||||
|
|
||||||
|
// prevent double close
|
||||||
|
if p.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.closed = true
|
||||||
|
|
||||||
|
p.cancel()
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
if err := p.remoteConn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.localConn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||||
|
}
|
||||||
|
return errors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||||
|
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||||
|
defer func() {
|
||||||
|
if err := p.close(); err != nil {
|
||||||
|
log.Warnf("error in proxy to remote loop: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
for p.ctx.Err() == nil {
|
||||||
|
n, err := p.localConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if p.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = p.remoteConn.Write(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
if p.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("failed to write to remote conn: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||||
|
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||||
|
defer func() {
|
||||||
|
if err := p.close(); err != nil {
|
||||||
|
log.Warnf("error in proxy to local loop: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
for p.ctx.Err() == nil {
|
||||||
|
n, err := p.remoteConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if p.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = p.localConn.Write(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
if p.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -160,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
|||||||
log.Fatalf("failed to listen: %v", err)
|
log.Fatalf("failed to listen: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
proto.RegisterSignalExchangeServer(s, srv)
|
proto.RegisterSignalExchangeServer(s, srv)
|
||||||
|
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -59,8 +59,8 @@ require (
|
|||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -521,12 +521,12 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
|
|||||||
@@ -20,11 +20,6 @@ import (
|
|||||||
cacheStore "github.com/eko/gocache/v3/store"
|
cacheStore "github.com/eko/gocache/v3/store"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
gocache "github.com/patrickmn/go-cache"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/base62"
|
"github.com/netbirdio/netbird/base62"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
@@ -41,6 +36,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration {
|
|||||||
|
|
||||||
type AccountManager interface {
|
type AccountManager interface {
|
||||||
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
|
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
|
||||||
|
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||||
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||||
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
|
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
|
||||||
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||||
@@ -75,12 +75,14 @@ type AccountManager interface {
|
|||||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
||||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
|
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
||||||
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
||||||
|
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||||
|
GetUserByID(ctx context.Context, id string) (*User, error)
|
||||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@@ -107,7 +109,7 @@ type AccountManager interface {
|
|||||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
|
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
||||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
@@ -145,6 +147,7 @@ type AccountManager interface {
|
|||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
|
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@@ -268,6 +271,11 @@ type AccountNetwork struct {
|
|||||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountDNSSettings used in gorm to only load dns settings and not whole account
|
||||||
|
type AccountDNSSettings struct {
|
||||||
|
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||||
|
}
|
||||||
|
|
||||||
type UserPermissions struct {
|
type UserPermissions struct {
|
||||||
DashboardView string `json:"dashboard_view"`
|
DashboardView string `json:"dashboard_view"`
|
||||||
}
|
}
|
||||||
@@ -1252,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
|
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
|
||||||
// userID doesn't have an account associated with it, one account is created
|
// If an accountID is provided, it checks if the account exists and returns it.
|
||||||
// domain is used to create a new account if no account is found
|
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
|
||||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) {
|
// If the user doesn't have an account, it creates one using the provided domain.
|
||||||
|
// Returns the account ID or an error if none is found or created.
|
||||||
|
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
||||||
if accountID != "" {
|
if accountID != "" {
|
||||||
return am.Store.GetAccount(ctx, accountID)
|
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||||
} else if userID != "" {
|
|
||||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
return "", err
|
||||||
}
|
}
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
|
if !exists {
|
||||||
if err != nil {
|
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return account, nil
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
|
if userID != "" {
|
||||||
|
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||||
|
if err != nil {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNil(i idp.Manager) bool {
|
func isNil(i idp.Manager) bool {
|
||||||
@@ -1613,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
|||||||
}
|
}
|
||||||
|
|
||||||
// redeemInvite checks whether user has been invited and redeems the invite
|
// redeemInvite checks whether user has been invited and redeems the invite
|
||||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
|
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
||||||
// only possible with the enabled IdP manager
|
// only possible with the enabled IdP manager
|
||||||
if am.idpManager == nil {
|
if am.idpManager == nil {
|
||||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1678,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
|
|||||||
return am.Store.SaveAccount(ctx, account)
|
return am.Store.SaveAccount(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccount returns an account associated with this account ID.
|
||||||
|
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
|
||||||
|
return am.Store.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
// GetAccountFromPAT returns Account and User associated with a personal access token
|
||||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
|
||||||
if len(token) != PATLength {
|
if len(token) != PATLength {
|
||||||
@@ -1726,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
|
|||||||
return account, user, pat, nil
|
return account, user, pat, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromToken returns an account associated with this token
|
// GetAccountByID returns an account associated with this account ID.
|
||||||
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||||
|
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return nil, nil, fmt.Errorf("user ID is empty")
|
return "", "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
@@ -1739,110 +1783,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return "", "", err
|
||||||
}
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
|
|
||||||
alreadyUnlocked := false
|
|
||||||
defer func() {
|
|
||||||
if !alreadyUnlocked {
|
|
||||||
unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, newAcc.Id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
user := account.Users[claims.UserId]
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||||
if user == nil {
|
if err != nil {
|
||||||
// this is not really possible because we got an account by user ID
|
// this is not really possible because we got an account by user ID
|
||||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsServiceUser && claims.Invited {
|
if !user.IsServiceUser && claims.Invited {
|
||||||
err = am.redeemInvite(ctx, account, claims.UserId)
|
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return "", "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.JWTGroupsEnabled {
|
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
|
||||||
if account.Settings.JWTGroupsClaimName == "" {
|
return "", "", err
|
||||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
|
||||||
return account, user, nil
|
|
||||||
}
|
|
||||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
|
||||||
if slice, ok := claim.([]interface{}); ok {
|
|
||||||
var groupsNames []string
|
|
||||||
for _, item := range slice {
|
|
||||||
if g, ok := item.(string); ok {
|
|
||||||
groupsNames = append(groupsNames, g)
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
oldGroups := make([]string, len(user.AutoGroups))
|
|
||||||
copy(oldGroups, user.AutoGroups)
|
|
||||||
// if groups were added or modified, save the account
|
|
||||||
if account.SetJWTGroups(claims.UserId, groupsNames) {
|
|
||||||
if account.Settings.GroupsPropagationEnabled {
|
|
||||||
if user, err := account.FindUser(claims.UserId); err == nil {
|
|
||||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
|
||||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
|
||||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
|
||||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
unlock()
|
|
||||||
alreadyUnlocked = true
|
|
||||||
for _, g := range addNewGroups {
|
|
||||||
if group := account.GetGroup(g); group != nil {
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
|
||||||
map[string]any{
|
|
||||||
"group": group.Name,
|
|
||||||
"group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser,
|
|
||||||
"user_name": user.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, g := range removeOldGroups {
|
|
||||||
if group := account.GetGroup(g); group != nil {
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
|
||||||
map[string]any{
|
|
||||||
"group": group.Name,
|
|
||||||
"group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser,
|
|
||||||
"user_name": user.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, user, nil
|
return accountID, user.Id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
|
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings == nil || !settings.JWTGroupsEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings.JWTGroupsClaimName == "" {
|
||||||
|
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Remove GetAccount after refactoring account peer's update
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
|
|
||||||
|
oldGroups := make([]string, len(user.AutoGroups))
|
||||||
|
copy(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
|
// Update the account if group membership changes
|
||||||
|
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
||||||
|
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||||
|
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
|
if settings.GroupsPropagationEnabled {
|
||||||
|
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||||
|
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||||
|
account.Network.IncSerial()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Propagate changes to peers if group propagation is enabled
|
||||||
|
if settings.GroupsPropagationEnabled {
|
||||||
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range addNewGroups {
|
||||||
|
if group := account.GetGroup(g); group != nil {
|
||||||
|
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||||
|
map[string]any{
|
||||||
|
"group": group.Name,
|
||||||
|
"group_id": group.ID,
|
||||||
|
"is_service_user": user.IsServiceUser,
|
||||||
|
"user_name": user.ServiceUserName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range removeOldGroups {
|
||||||
|
if group := account.GetGroup(g); group != nil {
|
||||||
|
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
||||||
|
map[string]any{
|
||||||
|
"group": group.Name,
|
||||||
|
"group_id": group.ID,
|
||||||
|
"is_service_user": user.IsServiceUser,
|
||||||
|
"user_name": user.ServiceUserName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||||
// if domain is of the PrivateCategory category, it will evaluate
|
// if domain is of the PrivateCategory category, it will evaluate
|
||||||
// if account is new, existing or if there is another account with the same domain
|
// if account is new, existing or if there is another account with the same domain
|
||||||
//
|
//
|
||||||
@@ -1859,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||||
//
|
//
|
||||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return nil, fmt.Errorf("user ID is empty")
|
return "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if Account ID is part of the claims
|
// if Account ID is part of the claims
|
||||||
// it means that we've already classified the domain and user has an account
|
// it means that we've already classified the domain and user has an account
|
||||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||||
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||||
} else if claims.AccountId != "" {
|
} else if claims.AccountId != "" {
|
||||||
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
|
||||||
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
if userAccountID != claims.AccountId {
|
||||||
|
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||||
}
|
}
|
||||||
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
|
||||||
return accountFromID, nil
|
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
||||||
|
return userAccountID, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1888,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
|||||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||||
|
|
||||||
// We checked if the domain has a primary account already
|
// We checked if the domain has a primary account already
|
||||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// if NotFound we are good to continue, otherwise return error
|
// if NotFound we are good to continue, otherwise return error
|
||||||
e, ok := status.FromError(err)
|
e, ok := status.FromError(err)
|
||||||
if !ok || e.Type() != status.NotFound {
|
if !ok || e.Type() != status.NotFound {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||||
// and peers that shouldn't be lost.
|
// and peers that shouldn't be lost.
|
||||||
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
|
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||||
|
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
||||||
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
return "", err
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return account, nil
|
|
||||||
|
return account.Id, nil
|
||||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
if domainAccount != nil {
|
var domainAccount *Account
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
if domainAccountID != "" {
|
||||||
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return am.handleNewUserAccount(ctx, domainAccount, claims)
|
|
||||||
|
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return account.Id, nil
|
||||||
} else {
|
} else {
|
||||||
// other error
|
// other error
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2022,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
|||||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||||
// group propagation and set the list of groups with access permissions.
|
// group propagation and set the list of groups with access permissions.
|
||||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensures JWT group synchronization to the management is enabled before,
|
// Ensures JWT group synchronization to the management is enabled before,
|
||||||
// filtering access based on the allowed groups.
|
// filtering access based on the allowed groups.
|
||||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
if settings != nil && settings.JWTGroupsEnabled {
|
||||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||||
userJWTGroups := make([]string, 0)
|
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
|
|
||||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
|
||||||
if claimGroups, ok := claim.([]interface{}); ok {
|
|
||||||
for _, g := range claimGroups {
|
|
||||||
if group, ok := g.(string); ok {
|
|
||||||
userJWTGroups = append(userJWTGroups, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||||
@@ -2111,6 +2164,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
|
|||||||
return newLabel, nil
|
return newLabel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// addAllGroup to account object if it doesn't exist
|
// addAllGroup to account object if it doesn't exist
|
||||||
func addAllGroup(account *Account) error {
|
func addAllGroup(account *Account) error {
|
||||||
if len(account.Groups) == 0 {
|
if len(account.Groups) == 0 {
|
||||||
@@ -2193,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
|||||||
return acc
|
return acc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||||
|
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
|
||||||
|
userJWTGroups := make([]string, 0)
|
||||||
|
|
||||||
|
if claim, ok := claims.Raw[claimName]; ok {
|
||||||
|
if claimGroups, ok := claim.([]interface{}); ok {
|
||||||
|
for _, g := range claimGroups {
|
||||||
|
if group, ok := g.(string); ok {
|
||||||
|
userJWTGroups = append(userJWTGroups, group)
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return userJWTGroups
|
||||||
|
}
|
||||||
|
|
||||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||||
for _, userGroup := range userGroups {
|
for _, userGroup := range userGroups {
|
||||||
|
|||||||
@@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
|||||||
assert.Equal(t, account.Id, ev.TargetID)
|
assert.Equal(t, account.Id, ev.TargetID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||||
type initUserParams jwtclaims.AuthorizationClaims
|
type initUserParams jwtclaims.AuthorizationClaims
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
@@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||||
require.NoError(t, err, "create init user failed")
|
require.NoError(t, err, "create init user failed")
|
||||||
|
|
||||||
|
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get init account failed")
|
||||||
|
|
||||||
if testCase.inputUpdateAttrs {
|
if testCase.inputUpdateAttrs {
|
||||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||||
require.NoError(t, err, "update init user failed")
|
require.NoError(t, err, "update init user failed")
|
||||||
@@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
|||||||
testCase.inputClaims.AccountId = initAccount.Id
|
testCase.inputClaims.AccountId = initAccount.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
|
||||||
require.NoError(t, err, "support function failed")
|
require.NoError(t, err, "support function failed")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get account failed")
|
||||||
|
|
||||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||||
|
|
||||||
@@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
accountID := initAccount.Id
|
accountID := initAccount.Id
|
||||||
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
|
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
|
||||||
require.NoError(t, err, "create init user failed")
|
require.NoError(t, err, "create init user failed")
|
||||||
// as initAccount was created without account id we have to take the id after account initialization
|
// as initAccount was created without account id we have to take the id after account initialization
|
||||||
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
|
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
|
||||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||||
initAccount = acc
|
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get init account failed")
|
||||||
|
|
||||||
claims := jwtclaims.AuthorizationClaims{
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||||
@@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get account failed")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "save account failed")
|
require.NoError(t, err, "save account failed")
|
||||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get account failed")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "save account failed")
|
require.NoError(t, err, "save account failed")
|
||||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "get account failed")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||||
|
|
||||||
groupsByNames := map[string]*group.Group{}
|
groupsByNames := map[string]*group.Group{}
|
||||||
@@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
|||||||
|
|
||||||
userId := "test_user"
|
userId := "test_user"
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if account == nil {
|
if accountID == "" {
|
||||||
t.Fatalf("expected to create an account for a user %s", userId)
|
t.Fatalf("expected to create an account for a user %s", userId)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
|
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
|
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected an error when user and account IDs are empty")
|
t.Errorf("expected an error when user and account IDs are empty")
|
||||||
}
|
}
|
||||||
@@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||||
t.Errorf("delete default rule: %v", err)
|
t.Errorf("delete default rule: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
assert.NotNil(t, account.Settings)
|
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||||
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
|
require.NoError(t, err, "unable to get account settings")
|
||||||
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
|
|
||||||
|
assert.NotNil(t, settings)
|
||||||
|
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
|
||||||
|
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
|
||||||
|
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
PeerLoginExpirationEnabled: true,
|
PeerLoginExpirationEnabled: true,
|
||||||
})
|
})
|
||||||
@@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
PeerLoginExpirationEnabled: true,
|
PeerLoginExpirationEnabled: true,
|
||||||
})
|
})
|
||||||
@@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
@@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||||
require.NoError(t, err, "unable to mark peer connected")
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
@@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
})
|
})
|
||||||
@@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||||
|
|
||||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||||
require.NoError(t, err, "unable to get account by ID")
|
require.NoError(t, err, "unable to get account by ID")
|
||||||
|
|
||||||
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
|
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||||
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
|
require.NoError(t, err, "unable to get account settings")
|
||||||
|
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||||
|
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
|
||||||
|
|
||||||
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Second,
|
PeerLoginExpiration: time.Second,
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
})
|
})
|
||||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||||
|
|
||||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings {
|
|||||||
|
|
||||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
||||||
}
|
}
|
||||||
dnsSettings := account.DNSSettings.Copy()
|
|
||||||
return &dnsSettings, nil
|
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||||
|
|||||||
@@ -10,14 +10,15 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/netbirdio/netbird/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
@@ -634,10 +635,19 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.Users[userID].Copy(), nil
|
user := account.Users[userID].Copy()
|
||||||
|
pat := make([]PersonalAccessToken, 0, len(user.PATs))
|
||||||
|
for _, token := range user.PATs {
|
||||||
|
if token != nil {
|
||||||
|
pat = append(pat, *token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
user.PATsG = pat
|
||||||
|
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -931,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
|
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -950,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
|
|||||||
return FileStoreEngine
|
return FileStoreEngine
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
|
func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
|
||||||
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
|
||||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
|
||||||
|
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Domain, account.DomainCategory, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountExists checks whether an account exists by the given ID.
|
||||||
|
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
|
||||||
|
_, exists := s.Accounts[id]
|
||||||
|
return exists, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string {
|
|||||||
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroup object of the peers
|
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||||
|
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
||||||
|
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroup returns a specific group by groupID in an account
|
||||||
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
|
||||||
}
|
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
|
||||||
if ok {
|
|
||||||
return group, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllGroups returns all groups in an account
|
// GetAllGroups returns all groups in an account
|
||||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
return am.Store.GetAccountGroups(ctx, accountID)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
|
||||||
}
|
|
||||||
|
|
||||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
|
||||||
for _, item := range account.Groups {
|
|
||||||
groups = append(groups, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
return groups, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
matchingGroups := make([]*nbgroup.Group, 0)
|
|
||||||
for _, group := range account.Groups {
|
|
||||||
if group.Name == groupName {
|
|
||||||
matchingGroups = append(matchingGroups, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(matchingGroups) == 0 {
|
|
||||||
return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName)
|
|
||||||
}
|
|
||||||
|
|
||||||
maxPeers := -1
|
|
||||||
var groupWithMostPeers *nbgroup.Group
|
|
||||||
for i, group := range matchingGroups {
|
|
||||||
if len(group.Peers) > maxPeers {
|
|
||||||
maxPeers = len(group.Peers)
|
|
||||||
groupWithMostPeers = matchingGroups[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return groupWithMostPeers, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGroup object of the peers
|
// SaveGroup object of the peers
|
||||||
@@ -262,6 +217,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allGroup, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if allGroup.ID == groupID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
|||||||
}
|
}
|
||||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||||
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
|
_, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
|||||||
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||||
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
|
||||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(account)
|
resp := toAccountResponse(accountID, settings)
|
||||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||||
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
|
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(updatedAccount)
|
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, &resp)
|
util.WriteJSONObject(r.Context(), w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAccount is a HTTP DELETE handler to delete an account
|
// DeleteAccount is a HTTP DELETE handler to delete an account
|
||||||
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodDelete {
|
|
||||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
targetAccountID := vars["accountId"]
|
targetAccountID := vars["accountId"]
|
||||||
@@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toAccountResponse(account *server.Account) *api.Account {
|
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
|
||||||
jwtAllowGroups := account.Settings.JWTAllowGroups
|
jwtAllowGroups := settings.JWTAllowGroups
|
||||||
if jwtAllowGroups == nil {
|
if jwtAllowGroups == nil {
|
||||||
jwtAllowGroups = []string{}
|
jwtAllowGroups = []string{}
|
||||||
}
|
}
|
||||||
|
|
||||||
settings := api.AccountSettings{
|
apiSettings := api.AccountSettings{
|
||||||
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
|
||||||
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
|
||||||
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
|
GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
|
||||||
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
|
JwtGroupsEnabled: &settings.JWTGroupsEnabled,
|
||||||
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
|
JwtGroupsClaimName: &settings.JWTGroupsClaimName,
|
||||||
JwtAllowGroups: &jwtAllowGroups,
|
JwtAllowGroups: &jwtAllowGroups,
|
||||||
RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked,
|
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.Extra != nil {
|
if settings.Extra != nil {
|
||||||
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
|
apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Account{
|
return &api.Account{
|
||||||
Id: account.Id,
|
Id: accountID,
|
||||||
Settings: settings,
|
Settings: apiSettings,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,8 +23,11 @@ import (
|
|||||||
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
||||||
return &AccountsHandler{
|
return &AccountsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return account, admin, nil
|
return account.Id, admin.Id, nil
|
||||||
|
},
|
||||||
|
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||||
|
return account.Settings, nil
|
||||||
},
|
},
|
||||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||||
halfYearLimit := 180 * 24 * time.Hour
|
halfYearLimit := 180 * 24 * time.Hour
|
||||||
|
|||||||
@@ -950,7 +950,7 @@ components:
|
|||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
|
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
|
||||||
action:
|
action:
|
||||||
description: Action to take upon policy match
|
description: Action to take upon policy match
|
||||||
type: string
|
type: string
|
||||||
|
|||||||
@@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
|
|||||||
// GetDNSSettings returns the DNS settings for the account
|
// GetDNSSettings returns the DNS settings for the account
|
||||||
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
|
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
|
|||||||
// UpdateDNSSettings handles update to DNS settings of an account
|
// UpdateDNSSettings handles update to DNS settings of an account
|
||||||
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
|
|||||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
|
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
|
|||||||
}
|
}
|
||||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
|
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
|||||||
@@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
|
|||||||
// GetAllEvents list of the given account
|
// GetAllEvents list of the given account
|
||||||
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
|
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
|||||||
events[i] = toEventResponse(e)
|
events[i] = toEventResponse(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
|
err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
|
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
|
||||||
return &EventsHandler{
|
return &EventsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||||
@@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
|||||||
}
|
}
|
||||||
return []*activity.Event{}, nil
|
return []*activity.Event{}, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return &server.Account{
|
return claims.AccountId, claims.UserId, nil
|
||||||
Id: claims.AccountId,
|
|
||||||
Domain: "hotmail.com",
|
|
||||||
Users: map[string]*server.User{
|
|
||||||
user.Id: user,
|
|
||||||
},
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||||
return make([]*server.UserInfo, 0), nil
|
return make([]*server.UserInfo, 0), nil
|
||||||
@@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
|||||||
accountID := "test_account"
|
accountID := "test_account"
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user")
|
||||||
events := generateEvents(accountID, adminUser.Id)
|
events := generateEvents(accountID, adminUser.Id)
|
||||||
handler := initEventsTestData(accountID, adminUser, events...)
|
handler := initEventsTestData(accountID, events...)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
|||||||
|
|
||||||
return &GeolocationsHandler{
|
return &GeolocationsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
return claims.AccountId, claims.UserId, nil
|
||||||
return &server.Account{
|
},
|
||||||
Id: claims.AccountId,
|
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||||
Users: map[string]*server.User{
|
return server.NewAdminUser(id), nil
|
||||||
"test_user": user,
|
|
||||||
},
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
geolocationManager: geo,
|
geolocationManager: geo,
|
||||||
|
|||||||
@@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
|
|||||||
|
|
||||||
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
||||||
claims := l.claimsExtractor.FromRequestContext(r)
|
claims := l.claimsExtractor.FromRequestContext(r)
|
||||||
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
|
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := l.accountManager.GetUserByID(r.Context(), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
|
|||||||
// GetAllGroups list for the account
|
// GetAllGroups list for the account
|
||||||
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
|
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
groupsResponse := make([]*api.Group, 0, len(groups))
|
groupsResponse := make([]*api.Group, 0, len(groups))
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
|
groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
||||||
@@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdateGroup handles update to a group identified by a given ID
|
// UpdateGroup handles update to a group identified by a given ID
|
||||||
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
eg, ok := account.Groups[groupID]
|
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||||
if !ok {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
if allGroup.ID == groupID {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||||
return
|
return
|
||||||
@@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
ID: groupID,
|
ID: groupID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
Issued: eg.Issued,
|
Issued: existingGroup.Issued,
|
||||||
IntegrationReference: eg.IntegrationReference,
|
IntegrationReference: existingGroup.IntegrationReference,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
|
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroup handles group creation request
|
// CreateGroup handles group creation request
|
||||||
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
Issued: nbgroup.GroupIssuedAPI,
|
Issued: nbgroup.GroupIssuedAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
|
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup handles group deletion request
|
// DeleteGroup handles group deletion request
|
||||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aID := account.Id
|
|
||||||
|
|
||||||
groupID := mux.Vars(r)["groupId"]
|
groupID := mux.Vars(r)["groupId"]
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
@@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, ok := err.(*server.GroupLinkError)
|
_, ok := err.(*server.GroupLinkError)
|
||||||
if ok {
|
if ok {
|
||||||
@@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetGroup returns a group
|
// GetGroup returns a group
|
||||||
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID := mux.Vars(r)["groupId"]
|
||||||
|
if len(groupID) == 0 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Method {
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
case http.MethodGet:
|
if err != nil {
|
||||||
groupID := mux.Vars(r)["groupId"]
|
util.WriteError(r.Context(), err, w)
|
||||||
if len(groupID) == 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
|
|
||||||
default:
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
|
func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
|
||||||
|
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||||
|
for _, peer := range peers {
|
||||||
|
peersMap[peer.ID] = peer
|
||||||
|
}
|
||||||
|
|
||||||
cache := make(map[string]api.PeerMinimum)
|
cache := make(map[string]api.PeerMinimum)
|
||||||
gr := api.Group{
|
gr := api.Group{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
@@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
|
|||||||
for _, pid := range group.Peers {
|
for _, pid := range group.Peers {
|
||||||
_, ok := cache[pid]
|
_, ok := cache[pid]
|
||||||
if !ok {
|
if !ok {
|
||||||
peer, ok := account.Peers[pid]
|
peer, ok := peersMap[pid]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/magiconair/properties/assert"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
@@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
|
|||||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||||
}
|
}
|
||||||
|
|
||||||
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
|
||||||
return &GroupsHandler{
|
return &GroupsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
||||||
@@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
|
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
|
||||||
if groupID != "idofthegroup" {
|
groups := map[string]*nbgroup.Group{
|
||||||
|
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
|
||||||
|
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
|
||||||
|
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range initGroups {
|
||||||
|
groups[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
|
group, ok := groups[groupID]
|
||||||
|
if !ok {
|
||||||
return nil, status.Errorf(status.NotFound, "not found")
|
return nil, status.Errorf(status.NotFound, "not found")
|
||||||
}
|
}
|
||||||
if groupID == "id-jwt-group" {
|
|
||||||
return &nbgroup.Group{
|
return group, nil
|
||||||
ID: "id-jwt-group",
|
|
||||||
Name: "Default Group",
|
|
||||||
Issued: nbgroup.GroupIssuedJWT,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return &nbgroup.Group{
|
|
||||||
ID: "idofthegroup",
|
|
||||||
Name: "Group",
|
|
||||||
Issued: nbgroup.GroupIssuedAPI,
|
|
||||||
}, nil
|
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return &server.Account{
|
return claims.AccountId, claims.UserId, nil
|
||||||
Id: claims.AccountId,
|
},
|
||||||
Domain: "hotmail.com",
|
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
|
||||||
Peers: TestPeers,
|
if groupName == "All" {
|
||||||
Users: map[string]*server.User{
|
return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
|
||||||
user.Id: user,
|
}
|
||||||
},
|
|
||||||
Groups: map[string]*nbgroup.Group{
|
return nil, fmt.Errorf("unknown group name")
|
||||||
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
|
},
|
||||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
|
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||||
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
|
return maps.Values(TestPeers), nil
|
||||||
},
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||||
if groupID == "linked-grp" {
|
if groupID == "linked-grp" {
|
||||||
@@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
|
|||||||
Name: "Group",
|
Name: "Group",
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
p := initGroupTestData(group)
|
||||||
p := initGroupTestData(adminUser, group)
|
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
@@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
p := initGroupTestData()
|
||||||
p := initGroupTestData(adminUser)
|
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
@@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
p := initGroupTestData()
|
||||||
p := initGroupTestData(adminUser)
|
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
|||||||
AuthCfg: authCfg,
|
AuthCfg: authCfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
|
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
|
||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
|
|||||||
// GetAllNameservers returns the list of nameserver groups for the account
|
// GetAllNameservers returns the list of nameserver groups for the account
|
||||||
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
|
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
|
|||||||
// CreateNameserverGroup handles nameserver group creation request
|
// CreateNameserverGroup handles nameserver group creation request
|
||||||
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
|
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||||
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
|
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// DeleteNameserverGroup handles nameserver group deletion request
|
// DeleteNameserverGroup handles nameserver group deletion request
|
||||||
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
|
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
||||||
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
@@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
|
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
@@ -29,14 +28,6 @@ const (
|
|||||||
testNSGroupAccountID = "test_id"
|
testNSGroupAccountID = "test_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testingNSAccount = &server.Account{
|
|
||||||
Id: testNSGroupAccountID,
|
|
||||||
Domain: "hotmail.com",
|
|
||||||
Users: map[string]*server.User{
|
|
||||||
"test_user": server.NewAdminUser("test_user"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var baseExistingNSGroup = &nbdns.NameServerGroup{
|
var baseExistingNSGroup = &nbdns.NameServerGroup{
|
||||||
ID: existingNSGroupID,
|
ID: existingNSGroupID,
|
||||||
Name: "super",
|
Name: "super",
|
||||||
@@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
|
|||||||
}
|
}
|
||||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
return claims.AccountId, claims.UserId, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
|||||||
@@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
|
|||||||
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||||
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
userID := vars["userId"]
|
targetUserID := vars["userId"]
|
||||||
if len(userID) == 0 {
|
if len(userID) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
|
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
||||||
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
||||||
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||||
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
|
|||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return testAccount, testAccount.Users[existingUserID], nil
|
return claims.AccountId, claims.UserId, nil
|
||||||
},
|
},
|
||||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
@@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
|
|||||||
return jwtclaims.AuthorizationClaims{
|
return jwtclaims.AuthorizationClaims{
|
||||||
UserId: existingUserID,
|
UserId: existingUserID,
|
||||||
Domain: testDomain,
|
Domain: testDomain,
|
||||||
AccountId: testNSGroupAccountID,
|
AccountId: existingAccountID,
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@@ -16,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeersHandler is a handler that returns peers of the account
|
// PeersHandler is a handler that returns peers of the account
|
||||||
@@ -75,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
|||||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||||
req := &api.PeerRequest{}
|
req := &api.PeerRequest{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -97,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
|
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
return
|
return
|
||||||
@@ -131,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
|
|||||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||||
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -145,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodDelete:
|
case http.MethodDelete:
|
||||||
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
|
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||||
return
|
return
|
||||||
case http.MethodPut:
|
case http.MethodGet, http.MethodPut:
|
||||||
h.updatePeer(r.Context(), account, user, peerID, w, r)
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
return
|
if err != nil {
|
||||||
case http.MethodGet:
|
util.WriteError(r.Context(), err, w)
|
||||||
h.getPeer(r.Context(), account, peerID, user.Id, w)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
|
h.getPeer(r.Context(), account, peerID, userID, w)
|
||||||
|
} else {
|
||||||
|
h.updatePeer(r.Context(), account, userID, peerID, w, r)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||||
@@ -160,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// GetAllPeers returns a list of all peers associated with a provided account
|
// GetAllPeers returns a list of all peers associated with a provided account
|
||||||
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -180,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
dnsDomain := h.accountManager.GetDNSDomain()
|
dnsDomain := h.accountManager.GetDNSDomain()
|
||||||
|
|
||||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||||
for _, peer := range peers {
|
for _, peer := range account.Peers {
|
||||||
peerToReturn, err := h.checkPeerStatus(peer)
|
peerToReturn, err := h.checkPeerStatus(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
@@ -215,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
|
|||||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||||
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -228,6 +229,33 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := account.FindUser(userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the user is regular user and does not own the peer
|
||||||
|
// with the given peerID return an empty list
|
||||||
|
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||||
|
peer, ok := account.Peers[peerID]
|
||||||
|
if !ok {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.UserID != user.Id {
|
||||||
|
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dnsDomain := h.accountManager.GetDNSDomain()
|
dnsDomain := h.accountManager.GetDNSDomain()
|
||||||
|
|
||||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -12,20 +13,29 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
||||||
|
|
||||||
"github.com/magiconair/properties/assert"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testPeerID = "test_peer"
|
type ctxKey string
|
||||||
const noUpdateChannelTestPeerID = "no-update-channel"
|
|
||||||
|
const (
|
||||||
|
testPeerID = "test_peer"
|
||||||
|
noUpdateChannelTestPeerID = "no-update-channel"
|
||||||
|
|
||||||
|
adminUser = "admin_user"
|
||||||
|
regularUser = "regular_user"
|
||||||
|
serviceUser = "service_user"
|
||||||
|
userIDKey ctxKey = "user_id"
|
||||||
|
)
|
||||||
|
|
||||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||||
return &PeersHandler{
|
return &PeersHandler{
|
||||||
@@ -59,22 +69,61 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
GetDNSDomainFunc: func() string {
|
GetDNSDomainFunc: func() string {
|
||||||
return "netbird.selfhosted"
|
return "netbird.selfhosted"
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
return claims.AccountId, claims.UserId, nil
|
||||||
return &server.Account{
|
},
|
||||||
Id: claims.AccountId,
|
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||||
Domain: "hotmail.com",
|
peersMap := make(map[string]*nbpeer.Peer)
|
||||||
Peers: map[string]*nbpeer.Peer{
|
for _, peer := range peers {
|
||||||
peers[0].ID: peers[0],
|
peersMap[peer.ID] = peer.Copy()
|
||||||
peers[1].ID: peers[1],
|
}
|
||||||
|
|
||||||
|
policy := &server.Policy{
|
||||||
|
ID: "policy",
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: "policy",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*server.PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "rule",
|
||||||
|
Name: "rule",
|
||||||
|
Enabled: true,
|
||||||
|
Action: "accept",
|
||||||
|
Destinations: []string{"group1"},
|
||||||
|
Sources: []string{"group1"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: "all",
|
||||||
|
Ports: []string{"80"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
srvUser := server.NewRegularUser(serviceUser)
|
||||||
|
srvUser.IsServiceUser = true
|
||||||
|
|
||||||
|
account := &server.Account{
|
||||||
|
Id: accountID,
|
||||||
|
Domain: "hotmail.com",
|
||||||
|
Peers: peersMap,
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": user,
|
adminUser: server.NewAdminUser(adminUser),
|
||||||
|
regularUser: server.NewRegularUser(regularUser),
|
||||||
|
serviceUser: srvUser,
|
||||||
|
},
|
||||||
|
Groups: map[string]*nbgroup.Group{
|
||||||
|
"group1": {
|
||||||
|
ID: "group1",
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: "group1",
|
||||||
|
Issued: "api",
|
||||||
|
Peers: maps.Keys(peersMap),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Settings: &server.Settings{
|
Settings: &server.Settings{
|
||||||
PeerLoginExpirationEnabled: true,
|
PeerLoginExpirationEnabled: true,
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
},
|
},
|
||||||
|
Policies: []*server.Policy{policy},
|
||||||
Network: &server.Network{
|
Network: &server.Network{
|
||||||
Identifier: "ciclqisab2ss43jdn8q0",
|
Identifier: "ciclqisab2ss43jdn8q0",
|
||||||
Net: net.IPNet{
|
Net: net.IPNet{
|
||||||
@@ -83,7 +132,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
},
|
},
|
||||||
Serial: 51,
|
Serial: 51,
|
||||||
},
|
},
|
||||||
}, user, nil
|
}
|
||||||
|
|
||||||
|
return account, nil
|
||||||
},
|
},
|
||||||
HasConnectedChannelFunc: func(peerID string) bool {
|
HasConnectedChannelFunc: func(peerID string) bool {
|
||||||
statuses := make(map[string]struct{})
|
statuses := make(map[string]struct{})
|
||||||
@@ -99,8 +150,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
|
userID := r.Context().Value(userIDKey).(string)
|
||||||
return jwtclaims.AuthorizationClaims{
|
return jwtclaims.AuthorizationClaims{
|
||||||
UserId: "test_user",
|
UserId: userID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
}
|
}
|
||||||
@@ -197,6 +249,8 @@ func TestGetPeers(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||||
@@ -227,9 +281,15 @@ func TestGetPeers(t *testing.T) {
|
|||||||
|
|
||||||
// hardcode this check for now as we only have two peers in this suite
|
// hardcode this check for now as we only have two peers in this suite
|
||||||
assert.Equal(t, len(respBody), 2)
|
assert.Equal(t, len(respBody), 2)
|
||||||
assert.Equal(t, respBody[1].Connected, false)
|
|
||||||
|
|
||||||
got = respBody[0]
|
for _, peer := range respBody {
|
||||||
|
if peer.Id == testPeerID {
|
||||||
|
got = peer
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, peer.Connected, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
got = &api.Peer{}
|
got = &api.Peer{}
|
||||||
err = json.Unmarshal(content, got)
|
err = json.Unmarshal(content, got)
|
||||||
@@ -251,3 +311,119 @@ func TestGetPeers(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetAccessiblePeers(t *testing.T) {
|
||||||
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
Key: "key1",
|
||||||
|
IP: net.ParseIP("100.64.0.1"),
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true},
|
||||||
|
Name: "peer1",
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: regularUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer2",
|
||||||
|
Key: "key2",
|
||||||
|
IP: net.ParseIP("100.64.0.2"),
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true},
|
||||||
|
Name: "peer2",
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: adminUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
peer3 := &nbpeer.Peer{
|
||||||
|
ID: "peer3",
|
||||||
|
Key: "key3",
|
||||||
|
IP: net.ParseIP("100.64.0.3"),
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true},
|
||||||
|
Name: "peer3",
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: regularUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
p := initTestMetaData(peer1, peer2, peer3)
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
peerID string
|
||||||
|
callerUserID string
|
||||||
|
expectedStatus int
|
||||||
|
expectedPeers []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non admin user can access owned peer",
|
||||||
|
peerID: "peer1",
|
||||||
|
callerUserID: regularUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{"peer2", "peer3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non admin user can't access unowned peer",
|
||||||
|
peerID: "peer2",
|
||||||
|
callerUserID: regularUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin user can access owned peer",
|
||||||
|
peerID: "peer2",
|
||||||
|
callerUserID: adminUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{"peer1", "peer3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin user can access unowned peer",
|
||||||
|
peerID: "peer3",
|
||||||
|
callerUserID: adminUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{"peer1", "peer2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "service user can access unowned peer",
|
||||||
|
peerID: "peer3",
|
||||||
|
callerUserID: serviceUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{"peer1", "peer2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||||
|
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
if res.StatusCode != tc.expectedStatus {
|
||||||
|
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
var accessiblePeers []api.AccessiblePeer
|
||||||
|
err = json.Unmarshal(body, &accessiblePeers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIDs := make([]string, len(accessiblePeers))
|
||||||
|
for i, peer := range accessiblePeers {
|
||||||
|
peerIDs[i] = peer.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
|||||||
// GetAllPolicies list for the account
|
// GetAllPolicies list for the account
|
||||||
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
|
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policies := []*api.Policy{}
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
for _, policy := range accountPolicies {
|
if err != nil {
|
||||||
resp := toPolicyResponse(account, policy)
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
policies := make([]*api.Policy, 0, len(listPolicies))
|
||||||
|
for _, policy := range listPolicies {
|
||||||
|
resp := toPolicyResponse(allGroups, policy)
|
||||||
if len(resp.Rules) == 0 {
|
if len(resp.Rules) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
return
|
return
|
||||||
@@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdatePolicy handles update to a policy identified by a given ID
|
// UpdatePolicy handles update to a policy identified by a given ID
|
||||||
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policyIdx := -1
|
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||||
for i, policy := range account.Policies {
|
|
||||||
if policy.ID == policyID {
|
|
||||||
policyIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if policyIdx < 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.savePolicy(w, r, account, user, policyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePolicy handles policy creation request
|
|
||||||
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.savePolicy(w, r, account, user, "")
|
h.savePolicy(w, r, accountID, userID, policyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePolicy handles policy creation request
|
||||||
|
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.savePolicy(w, r, accountID, userID, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// savePolicy handles policy creation and update
|
// savePolicy handles policy creation and update
|
||||||
func (h *Policies) savePolicy(
|
func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
account *server.Account,
|
|
||||||
user *server.User,
|
|
||||||
policyID string,
|
|
||||||
) {
|
|
||||||
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@@ -127,6 +122,8 @@ func (h *Policies) savePolicy(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isUpdate := policyID != ""
|
||||||
|
|
||||||
if policyID == "" {
|
if policyID == "" {
|
||||||
policyID = xid.New().String()
|
policyID = xid.New().String()
|
||||||
}
|
}
|
||||||
@@ -141,8 +138,8 @@ func (h *Policies) savePolicy(
|
|||||||
pr := server.PolicyRule{
|
pr := server.PolicyRule{
|
||||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||||
Name: rule.Name,
|
Name: rule.Name,
|
||||||
Destinations: groupMinimumsToStrings(account, rule.Destinations),
|
Destinations: rule.Destinations,
|
||||||
Sources: groupMinimumsToStrings(account, rule.Sources),
|
Sources: rule.Sources,
|
||||||
Bidirectional: rule.Bidirectional,
|
Bidirectional: rule.Bidirectional,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,15 +204,21 @@ func (h *Policies) savePolicy(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.SourcePostureChecks != nil {
|
if req.SourcePostureChecks != nil {
|
||||||
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
|
policy.SourcePostureChecks = *req.SourcePostureChecks
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
|
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toPolicyResponse(account, &policy)
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := toPolicyResponse(allGroups, &policy)
|
||||||
if len(resp.Rules) == 0 {
|
if len(resp.Rules) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
return
|
return
|
||||||
@@ -227,12 +230,11 @@ func (h *Policies) savePolicy(
|
|||||||
// DeletePolicy handles policy deletion request
|
// DeletePolicy handles policy deletion request
|
||||||
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aID := account.Id
|
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
policyID := vars["policyId"]
|
policyID := vars["policyId"]
|
||||||
@@ -241,7 +243,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
|
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -252,40 +254,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetPolicy handles a group Get request identified by ID
|
// GetPolicy handles a group Get request identified by ID
|
||||||
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Method {
|
vars := mux.Vars(r)
|
||||||
case http.MethodGet:
|
policyID := vars["policyId"]
|
||||||
vars := mux.Vars(r)
|
if len(policyID) == 0 {
|
||||||
policyID := vars["policyId"]
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||||
if len(policyID) == 0 {
|
return
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := toPolicyResponse(account, policy)
|
|
||||||
if len(resp.Rules) == 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, resp)
|
|
||||||
default:
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := toPolicyResponse(allGroups, policy)
|
||||||
|
if len(resp.Rules) == 0 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
|
func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
|
||||||
|
groupsMap := make(map[string]*nbgroup.Group)
|
||||||
|
for _, group := range groups {
|
||||||
|
groupsMap[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
cache := make(map[string]api.GroupMinimum)
|
cache := make(map[string]api.GroupMinimum)
|
||||||
ap := &api.Policy{
|
ap := &api.Policy{
|
||||||
Id: &policy.ID,
|
Id: &policy.ID,
|
||||||
@@ -306,16 +314,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
Protocol: api.PolicyRuleProtocol(r.Protocol),
|
Protocol: api.PolicyRuleProtocol(r.Protocol),
|
||||||
Action: api.PolicyRuleAction(r.Action),
|
Action: api.PolicyRuleAction(r.Action),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(r.Ports) != 0 {
|
if len(r.Ports) != 0 {
|
||||||
portsCopy := r.Ports
|
portsCopy := r.Ports
|
||||||
rule.Ports = &portsCopy
|
rule.Ports = &portsCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, gid := range r.Sources {
|
for _, gid := range r.Sources {
|
||||||
_, ok := cache[gid]
|
_, ok := cache[gid]
|
||||||
if ok {
|
if ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if group, ok := account.Groups[gid]; ok {
|
if group, ok := groupsMap[gid]; ok {
|
||||||
minimum := api.GroupMinimum{
|
minimum := api.GroupMinimum{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
@@ -325,13 +335,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
cache[gid] = minimum
|
cache[gid] = minimum
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, gid := range r.Destinations {
|
for _, gid := range r.Destinations {
|
||||||
cachedMinimum, ok := cache[gid]
|
cachedMinimum, ok := cache[gid]
|
||||||
if ok {
|
if ok {
|
||||||
rule.Destinations = append(rule.Destinations, cachedMinimum)
|
rule.Destinations = append(rule.Destinations, cachedMinimum)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if group, ok := account.Groups[gid]; ok {
|
if group, ok := groupsMap[gid]; ok {
|
||||||
minimum := api.GroupMinimum{
|
minimum := api.GroupMinimum{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
@@ -345,28 +356,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
}
|
}
|
||||||
return ap
|
return ap
|
||||||
}
|
}
|
||||||
|
|
||||||
func groupMinimumsToStrings(account *server.Account, gm []string) []string {
|
|
||||||
result := make([]string, 0, len(gm))
|
|
||||||
for _, g := range gm {
|
|
||||||
if _, ok := account.Groups[g]; !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result = append(result, g)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
|
|
||||||
result := make([]string, 0, len(postureChecksIds))
|
|
||||||
for _, id := range postureChecksIds {
|
|
||||||
for _, postureCheck := range account.PostureChecks {
|
|
||||||
if id == postureCheck.ID {
|
|
||||||
result = append(result, id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
|||||||
}
|
}
|
||||||
return policy, nil
|
return policy, nil
|
||||||
},
|
},
|
||||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
|
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
||||||
if !strings.HasPrefix(policy.ID, "id-") {
|
if !strings.HasPrefix(policy.ID, "id-") {
|
||||||
policy.ID = "id-was-set"
|
policy.ID = "id-was-set"
|
||||||
policy.Rules[0].ID = "id-was-set"
|
policy.Rules[0].ID = "id-was-set"
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||||
|
},
|
||||||
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
|
return claims.AccountId, claims.UserId, nil
|
||||||
|
},
|
||||||
|
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||||
|
user := server.NewAdminUser(userID)
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: accountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
Policies: []*server.Policy{
|
Policies: []*server.Policy{
|
||||||
{ID: "id-existed"},
|
{ID: "id-existed"},
|
||||||
@@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
|||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": user,
|
"test_user": user,
|
||||||
},
|
},
|
||||||
}, user, nil
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
|||||||
@@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
|
|||||||
// GetAllPostureChecks list for the account
|
// GetAllPostureChecks list for the account
|
||||||
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
|
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := []*api.PostureCheck{}
|
postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
|
||||||
for _, postureCheck := range accountPostureChecks {
|
for _, postureCheck := range listPostureChecks {
|
||||||
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
|
|||||||
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
||||||
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecksIdx := -1
|
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||||
for i, postureCheck := range account.PostureChecks {
|
if err != nil {
|
||||||
if postureCheck.ID == postureChecksID {
|
util.WriteError(r.Context(), err, w)
|
||||||
postureChecksIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if postureChecksIdx < 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.savePostureChecks(w, r, account, user, postureChecksID)
|
p.savePostureChecks(w, r, accountID, userID, postureChecksID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePostureCheck handles posture check creation request
|
// CreatePostureCheck handles posture check creation request
|
||||||
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.savePostureChecks(w, r, account, user, "")
|
p.savePostureChecks(w, r, accountID, userID, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPostureCheck handles a posture check Get request identified by ID
|
// GetPostureCheck handles a posture check Get request identified by ID
|
||||||
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
|
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
|||||||
// DeletePostureCheck handles posture check deletion request
|
// DeletePostureCheck handles posture check deletion request
|
||||||
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
|
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// savePostureChecks handles posture checks create and update
|
// savePostureChecks handles posture checks create and update
|
||||||
func (p *PostureChecksHandler) savePostureChecks(
|
func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
account *server.Account,
|
|
||||||
user *server.User,
|
|
||||||
postureChecksID string,
|
|
||||||
) {
|
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
req api.PostureCheckUpdate
|
req api.PostureCheckUpdate
|
||||||
@@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
|
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
|||||||
}
|
}
|
||||||
return accountPostureChecks, nil
|
return accountPostureChecks, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
return claims.AccountId, claims.UserId, nil
|
||||||
return &server.Account{
|
|
||||||
Id: claims.AccountId,
|
|
||||||
Users: map[string]*server.User{
|
|
||||||
"test_user": user,
|
|
||||||
},
|
|
||||||
PostureChecks: postureChecks,
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
geolocationManager: &geolocation.Geolocation{},
|
geolocationManager: &geolocation.Geolocation{},
|
||||||
|
|||||||
@@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
|
|||||||
// GetAllRoutes returns the list of routes for the account
|
// GetAllRoutes returns the list of routes for the account
|
||||||
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
|
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
|||||||
// CreateRoute handles route creation request
|
// CreateRoute handles route creation request
|
||||||
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
peerGroupIds = *req.PeerGroups
|
peerGroupIds = *req.PeerGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not allow non-Linux peers
|
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||||
if peer := account.GetPeer(peerId); peer != nil {
|
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
|
||||||
if peer.Meta.GoOS != "linux" {
|
)
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
|||||||
// UpdateRoute handles update to a route identified by a given ID
|
// UpdateRoute handles update to a route identified by a given ID
|
||||||
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
peerID = *req.Peer
|
peerID = *req.Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// do not allow non Linux peers
|
|
||||||
if peer := account.GetPeer(peerID); peer != nil {
|
|
||||||
if peer.Meta.GoOS != "linux" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute := &route.Route{
|
newRoute := &route.Route{
|
||||||
ID: route.ID(routeID),
|
ID: route.ID(routeID),
|
||||||
NetID: route.NetID(req.NetworkId),
|
NetID: route.NetID(req.NetworkId),
|
||||||
@@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
newRoute.PeerGroups = *req.PeerGroups
|
newRoute.PeerGroups = *req.PeerGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
|
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
// DeleteRoute handles route deletion request
|
// DeleteRoute handles route deletion request
|
||||||
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetRoute handles a route Get request identified by ID
|
// GetRoute handles a route Get request identified by ID
|
||||||
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
|
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
|
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
|
||||||
}
|
}
|
||||||
|
if peerID != "" {
|
||||||
|
if peerID == nonLinuxExistingPeerID {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &route.Route{
|
return &route.Route{
|
||||||
ID: existingRouteID,
|
ID: existingRouteID,
|
||||||
NetID: netID,
|
NetID: netID,
|
||||||
@@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
if r.Peer == notFoundPeerID {
|
if r.Peer == notFoundPeerID {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
|
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.Peer == nonLinuxExistingPeerID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
|
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
|
||||||
@@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return testingAccount, testingAccount.Users["test_user"], nil
|
//return testingAccount, testingAccount.Users["test_user"], nil
|
||||||
|
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
|
|||||||
// CreateSetupKey is a POST requests that creates a new SetupKey
|
// CreateSetupKey is a POST requests that creates a new SetupKey
|
||||||
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
if req.Ephemeral != nil {
|
if req.Ephemeral != nil {
|
||||||
ephemeral = *req.Ephemeral
|
ephemeral = *req.Ephemeral
|
||||||
}
|
}
|
||||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||||
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
|
req.AutoGroups, req.UsageLimit, userID, ephemeral)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
// GetSetupKey is a GET request to get a SetupKey by ID
|
||||||
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
|
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdateSetupKey is a PUT request to update server.SetupKey
|
// UpdateSetupKey is a PUT request to update server.SetupKey
|
||||||
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
newKey.Name = req.Name
|
newKey.Name = req.Name
|
||||||
newKey.Id = keyID
|
newKey.Id = keyID
|
||||||
|
|
||||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
|
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
||||||
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
|
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
@@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
|||||||
) *SetupKeysHandler {
|
) *SetupKeysHandler {
|
||||||
return &SetupKeysHandler{
|
return &SetupKeysHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return &server.Account{
|
return claims.AccountId, claims.UserId, nil
|
||||||
Id: testAccountID,
|
|
||||||
Domain: "hotmail.com",
|
|
||||||
Users: map[string]*server.User{
|
|
||||||
user.Id: user,
|
|
||||||
},
|
|
||||||
SetupKeys: map[string]*server.SetupKey{
|
|
||||||
defaultKey.Key: defaultKey,
|
|
||||||
},
|
|
||||||
Groups: map[string]*nbgroup.Group{
|
|
||||||
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
|
||||||
"id-all": {ID: "id-all", Name: "All"},
|
|
||||||
},
|
|
||||||
}, user, nil
|
|
||||||
},
|
},
|
||||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||||
_ int, _ string, ephemeral bool,
|
_ int, _ string, ephemeral bool,
|
||||||
|
|||||||
@@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
userID := vars["userId"]
|
targetUserID := vars["userId"]
|
||||||
if len(userID) == 0 {
|
if len(targetUserID) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
existingUser, ok := account.Users[userID]
|
existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
|
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
|
||||||
Id: userID,
|
Id: targetUserID,
|
||||||
Role: userRole,
|
Role: userRole,
|
||||||
AutoGroups: req.AutoGroups,
|
AutoGroups: req.AutoGroups,
|
||||||
Blocked: req.IsBlocked,
|
Blocked: req.IsBlocked,
|
||||||
@@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
|
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
name = *req.Name
|
name = *req.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
|
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
|
||||||
Email: email,
|
Email: email,
|
||||||
Name: name,
|
Name: name,
|
||||||
Role: req.Role,
|
Role: req.Role,
|
||||||
@@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
|
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
|
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
|
|||||||
func initUsersTestData() *UsersHandler {
|
func initUsersTestData() *UsersHandler {
|
||||||
return &UsersHandler{
|
return &UsersHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
|
return usersTestAccount.Id, claims.UserId, nil
|
||||||
|
},
|
||||||
|
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||||
|
return usersTestAccount.Users[id], nil
|
||||||
},
|
},
|
||||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||||
users := make([]*server.UserInfo, 0)
|
users := make([]*server.UserInfo, 0)
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ package idp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -97,6 +99,42 @@ type zitadelUserResponse struct {
|
|||||||
PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"`
|
PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readZitadelError parses errors returned by the zitadel APIs from a response.
|
||||||
|
func readZitadelError(body io.ReadCloser) error {
|
||||||
|
bodyBytes, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
helper := JsonParser{}
|
||||||
|
var target map[string]interface{}
|
||||||
|
err = helper.Unmarshal(bodyBytes, &target)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error unparsable body: %s", string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure keys are ordered for consistent logging behaviour.
|
||||||
|
errorKeys := make([]string, 0, len(target))
|
||||||
|
for k := range target {
|
||||||
|
errorKeys = append(errorKeys, k)
|
||||||
|
}
|
||||||
|
slices.Sort(errorKeys)
|
||||||
|
|
||||||
|
var errsOut []string
|
||||||
|
for _, k := range errorKeys {
|
||||||
|
if _, isEmbedded := target[k].(map[string]interface{}); isEmbedded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
errsOut = append(errsOut, fmt.Sprintf("%s: %v", k, target[k]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errsOut) == 0 {
|
||||||
|
return errors.New("unknown error")
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New(strings.Join(errsOut, " "))
|
||||||
|
}
|
||||||
|
|
||||||
// NewZitadelManager creates a new instance of the ZitadelManager.
|
// NewZitadelManager creates a new instance of the ZitadelManager.
|
||||||
func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
|
func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
@@ -176,7 +214,8 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon
|
|||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
|
zErr := readZitadelError(resp.Body)
|
||||||
|
return nil, fmt.Errorf("unable to get zitadel token, statusCode %d, zitadel: %w", resp.StatusCode, zErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@@ -489,7 +528,9 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string
|
|||||||
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
|
zErr := readZitadelError(resp.Body)
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to post %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return io.ReadAll(resp.Body)
|
return io.ReadAll(resp.Body)
|
||||||
@@ -561,7 +602,9 @@ func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values
|
|||||||
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
zErr := readZitadelError(resp.Body)
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to get %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return io.ReadAll(resp.Body)
|
return io.ReadAll(resp.Body)
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ func TestNewZitadelManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestZitadelRequestJWTToken(t *testing.T) {
|
func TestZitadelRequestJWTToken(t *testing.T) {
|
||||||
|
|
||||||
type requestJWTTokenTest struct {
|
type requestJWTTokenTest struct {
|
||||||
name string
|
name string
|
||||||
inputCode int
|
inputCode int
|
||||||
@@ -88,15 +87,14 @@ func TestZitadelRequestJWTToken(t *testing.T) {
|
|||||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||||
name: "Request Bad Status Code",
|
name: "Request Bad Status Code",
|
||||||
inputCode: 400,
|
inputCode: 400,
|
||||||
inputRespBody: "{}",
|
inputRespBody: "{\"error\": \"invalid_scope\", \"error_description\":\"openid missing\"}",
|
||||||
helper: JsonParser{},
|
helper: JsonParser{},
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: error: invalid_scope error_description: openid missing"),
|
||||||
expectedToken: "",
|
expectedToken: "",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
jwtReqClient := mockHTTPClient{
|
jwtReqClient := mockHTTPClient{
|
||||||
resBody: testCase.inputRespBody,
|
resBody: testCase.inputRespBody,
|
||||||
code: testCase.inputCode,
|
code: testCase.inputCode,
|
||||||
@@ -156,7 +154,7 @@ func TestZitadelParseRequestJWTResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||||
name: "Parse Bad json JWT Body",
|
name: "Parse Bad json JWT Body",
|
||||||
inputRespBody: "",
|
inputRespBody: "{}",
|
||||||
helper: JsonParser{},
|
helper: JsonParser{},
|
||||||
expectedToken: "",
|
expectedToken: "",
|
||||||
expectedExpiresIn: 0,
|
expectedExpiresIn: 0,
|
||||||
@@ -254,7 +252,7 @@ func TestZitadelAuthenticate(t *testing.T) {
|
|||||||
inputCode: 400,
|
inputCode: 400,
|
||||||
inputResBody: "{}",
|
inputResBody: "{}",
|
||||||
helper: JsonParser{},
|
helper: JsonParser{},
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
|
expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: unknown error"),
|
||||||
expectedCode: 200,
|
expectedCode: 200,
|
||||||
expectedToken: "",
|
expectedToken: "",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,10 +23,11 @@ import (
|
|||||||
|
|
||||||
type MockAccountManager struct {
|
type MockAccountManager struct {
|
||||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
||||||
|
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
|
||||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
|
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
|
||||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@@ -48,7 +49,7 @@ type MockAccountManager struct {
|
|||||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
|
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
||||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||||
@@ -79,7 +80,7 @@ type MockAccountManager struct {
|
|||||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||||
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||||
GetDNSDomainFunc func() string
|
GetDNSDomainFunc func() string
|
||||||
@@ -105,6 +106,9 @@ type MockAccountManager struct {
|
|||||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||||
|
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
|
||||||
|
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
|
||||||
|
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||||
@@ -190,16 +194,14 @@ func (am *MockAccountManager) CreateSetupKey(
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
|
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
|
||||||
ctx context.Context, userId, accountId, domain string,
|
if am.GetAccountIDByUserOrAccountIdFunc != nil {
|
||||||
) (*server.Account, error) {
|
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
||||||
if am.GetAccountByUserOrAccountIdFunc != nil {
|
|
||||||
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(
|
return "", status.Errorf(
|
||||||
codes.Unimplemented,
|
codes.Unimplemented,
|
||||||
"method GetAccountByUserOrAccountID is not implemented",
|
"method GetAccountIDByUserOrAccountID is not implemented",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,9 +379,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
|
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
|
||||||
if am.SavePolicyFunc != nil {
|
if am.SavePolicyFunc != nil {
|
||||||
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||||
}
|
}
|
||||||
@@ -601,14 +603,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
|
func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
error,
|
if am.GetAccountIDFromTokenFunc != nil {
|
||||||
) {
|
return am.GetAccountIDFromTokenFunc(ctx, claims)
|
||||||
if am.GetAccountFromTokenFunc != nil {
|
|
||||||
return am.GetAccountFromTokenFunc(ctx, claims)
|
|
||||||
}
|
}
|
||||||
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||||
@@ -802,3 +802,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
|
|||||||
}
|
}
|
||||||
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
|
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountByID mocks GetAccountByID of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||||
|
if am.GetAccountByIDFunc != nil {
|
||||||
|
return am.GetAccountByIDFunc(ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByID mocks GetUserByID of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
|
||||||
|
if am.GetUserByIDFunc != nil {
|
||||||
|
return am.GetUserByIDFunc(ctx, id)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||||
|
if am.GetAccountSettingsFunc != nil {
|
||||||
|
return am.GetAccountSettingsFunc(ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
|
||||||
|
if am.GetAccountFunc != nil {
|
||||||
|
return am.GetAccountFunc(ctx, accountID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
|||||||
|
|
||||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups")
|
|
||||||
}
|
|
||||||
|
|
||||||
nsGroup, found := account.NameServerGroups[nsGroupID]
|
|
||||||
if found {
|
|
||||||
return nsGroup.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
@@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
|||||||
|
|
||||||
// ListNameServerGroups returns a list of nameserver groups from account
|
// ListNameServerGroups returns a list of nameserver groups from account
|
||||||
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
|
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
for _, item := range account.NameServerGroups {
|
|
||||||
nsGroups = append(nsGroups, item.Copy())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nsGroups, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||||
|
|||||||
@@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||||
return
|
return
|
||||||
@@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
policy.Enabled = false
|
policy.Enabled = false
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -314,34 +315,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
|||||||
|
|
||||||
// GetPolicy from the store
|
// GetPolicy from the store
|
||||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, policy := range account.Policies {
|
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||||
if policy.ID == policyID {
|
|
||||||
return policy, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
|
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@@ -350,7 +337,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
exists := am.savePolicy(account, policy)
|
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
@@ -358,7 +347,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
action := activity.PolicyAdded
|
action := activity.PolicyAdded
|
||||||
if exists {
|
if isUpdate {
|
||||||
action = activity.PolicyUpdated
|
action = activity.PolicyUpdated
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
@@ -397,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
|
|
||||||
// ListPolicies from the store
|
// ListPolicies from the store
|
||||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Policies, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||||
@@ -434,18 +415,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
|||||||
return policy, nil
|
return policy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) {
|
// savePolicy saves or updates a policy in the given account.
|
||||||
for i, p := range account.Policies {
|
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||||
if p.ID == policy.ID {
|
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
||||||
account.Policies[i] = policy
|
for index, rule := range policyToSave.Rules {
|
||||||
exists = true
|
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||||
break
|
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||||
|
policyToSave.Rules[index] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
if policyToSave.SourcePostureChecks != nil {
|
||||||
|
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUpdate {
|
||||||
|
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||||
|
if policyIdx < 0 {
|
||||||
|
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update the existing policy
|
||||||
|
account.Policies[policyIdx] = policyToSave
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
if !exists {
|
|
||||||
account.Policies = append(account.Policies, policy)
|
// Add the new policy to the account
|
||||||
}
|
account.Policies = append(account.Policies, policyToSave)
|
||||||
return
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||||
@@ -560,3 +557,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
||||||
|
// that are valid within the provided account.
|
||||||
|
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
||||||
|
result := make([]string, 0, len(postureChecksIds))
|
||||||
|
for _, id := range postureChecksIds {
|
||||||
|
for _, postureCheck := range account.PostureChecks {
|
||||||
|
if id == postureCheck.ID {
|
||||||
|
result = append(result, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
||||||
|
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
||||||
|
result := make([]string, 0, len(groupIDs))
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
if _, exists := account.Groups[groupID]; exists {
|
||||||
|
result = append(result, groupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,30 +15,16 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, postureChecks := range account.PostureChecks {
|
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||||
if postureChecks.ID == postureChecksID {
|
|
||||||
return postureChecks, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||||
@@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.PostureChecks, nil
|
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
||||||
|
|||||||
@@ -17,29 +17,16 @@ import (
|
|||||||
|
|
||||||
// GetRoute gets a route object from account and route IDs
|
// GetRoute gets a route object from account and route IDs
|
||||||
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
wantedRoute, found := account.Routes[routeID]
|
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
||||||
if found {
|
|
||||||
return wantedRoute, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||||
@@ -134,6 +121,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do not allow non-Linux peers
|
||||||
|
if peer := account.GetPeer(peerID); peer != nil {
|
||||||
|
if peer.Meta.GoOS != "linux" {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(domains) > 0 && prefix.IsValid() {
|
if len(domains) > 0 && prefix.IsValid() {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||||
}
|
}
|
||||||
@@ -234,6 +228,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do not allow non-Linux peers
|
||||||
|
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
|
||||||
|
if peer.Meta.GoOS != "linux" {
|
||||||
|
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||||
}
|
}
|
||||||
@@ -311,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
|||||||
|
|
||||||
// ListRoutes returns a list of routes from account
|
// ListRoutes returns a list of routes from account
|
||||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := make([]*route.Route, 0, len(account.Routes))
|
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||||
for _, item := range account.Routes {
|
|
||||||
routes = append(routes, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||||
|
|||||||
@@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
|||||||
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
||||||
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
||||||
|
|
||||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
|
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
||||||
|
|||||||
@@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
|
|
||||||
// ListSetupKeys returns a list of all setup keys of the account
|
// ListSetupKeys returns a list of all setup keys of the account
|
||||||
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
|
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
|
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
keys := make([]*SetupKey, 0, len(setupKeys))
|
||||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
|
for _, key := range setupKeys {
|
||||||
}
|
|
||||||
|
|
||||||
keys := make([]*SetupKey, 0, len(account.SetupKeys))
|
|
||||||
for _, key := range account.SetupKeys {
|
|
||||||
var k *SetupKey
|
var k *SetupKey
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
if !user.IsAdminOrServiceUser() {
|
||||||
k = key.HiddenCopy(999)
|
k = key.HiddenCopy(999)
|
||||||
} else {
|
} else {
|
||||||
k = key.Copy()
|
k = key.Copy()
|
||||||
@@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
|||||||
|
|
||||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||||
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
|
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
|
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
|
||||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
|
|
||||||
}
|
|
||||||
|
|
||||||
var foundKey *SetupKey
|
|
||||||
for _, key := range account.SetupKeys {
|
|
||||||
if key.Id == keyID {
|
|
||||||
foundKey = key.Copy()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if foundKey == nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
||||||
if foundKey.UpdatedAt.IsZero() {
|
if setupKey.UpdatedAt.IsZero() {
|
||||||
foundKey.UpdatedAt = foundKey.CreatedAt
|
setupKey.UpdatedAt = setupKey.CreatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
if !user.IsAdminOrServiceUser() {
|
||||||
foundKey = foundKey.HiddenCopy(999)
|
setupKey = setupKey.HiddenCopy(999)
|
||||||
}
|
}
|
||||||
|
|
||||||
return foundKey, nil
|
return setupKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ const (
|
|||||||
idQueryCondition = "id = ?"
|
idQueryCondition = "id = ?"
|
||||||
keyQueryCondition = "key = ?"
|
keyQueryCondition = "key = ?"
|
||||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||||
|
accountIDCondition = "account_id = ?"
|
||||||
peerNotFoundFMT = "peer %s not found"
|
peerNotFoundFMT = "peer %s not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -399,20 +400,30 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
||||||
var account Account
|
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||||
|
if err != nil {
|
||||||
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
return nil, err
|
||||||
strings.ToLower(domain), true, PrivateCategory)
|
|
||||||
if result.Error != nil {
|
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
|
||||||
}
|
|
||||||
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: rework to not call GetAccount
|
// TODO: rework to not call GetAccount
|
||||||
return s.GetAccount(ctx, account.Id)
|
return s.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||||
|
var accountID string
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||||
|
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||||
|
strings.ToLower(domain), true, PrivateCategory,
|
||||||
|
).First(&accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
||||||
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||||
@@ -478,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
|||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
First(&user, idQueryCondition, userID)
|
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
@@ -491,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.Find(&groups, idQueryCondition, accountID)
|
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
@@ -661,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||||
var user User
|
|
||||||
var accountID string
|
var accountID string
|
||||||
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
@@ -1024,3 +1034,156 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
|||||||
db: tx,
|
db: tx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetDB() *gorm.DB {
|
||||||
|
return s.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
|
||||||
|
var accountDNSSettings AccountDNSSettings
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||||
|
}
|
||||||
|
return &accountDNSSettings.DNSSettings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountExists checks whether an account exists by the given ID.
|
||||||
|
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||||
|
var accountID string
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
Select("id").First(&accountID, idQueryCondition, id)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountID != "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||||
|
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||||
|
var account Account
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||||
|
Where(idQueryCondition, accountID).First(&account)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return "", "", status.Errorf(status.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Domain, account.DomainCategory, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroupByID retrieves a group by ID and account ID.
|
||||||
|
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
||||||
|
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroupByName retrieves a group by name and account ID.
|
||||||
|
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
|
var group nbgroup.Group
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||||
|
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "group not found")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
|
||||||
|
}
|
||||||
|
return &group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountPolicies retrieves policies for an account.
|
||||||
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||||
|
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||||
|
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||||
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||||
|
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||||
|
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||||
|
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountRoutes retrieves network routes for an account.
|
||||||
|
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||||
|
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRouteByID retrieves a route by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
||||||
|
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||||
|
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||||
|
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
|
||||||
|
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||||
|
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
|
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
||||||
|
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRecords retrieves records from the database based on the account ID.
|
||||||
|
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||||
|
var record []T
|
||||||
|
|
||||||
|
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||||
|
recordType := parts[len(parts)-1]
|
||||||
|
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRecordByID retrieves a record by its ID and account ID from the database.
|
||||||
|
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
|
||||||
|
var record T
|
||||||
|
|
||||||
|
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||||
|
recordType := parts[len(parts)-1]
|
||||||
|
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
|
||||||
|
}
|
||||||
|
return &record, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
@@ -39,53 +40,81 @@ const (
|
|||||||
type Store interface {
|
type Store interface {
|
||||||
GetAllAccounts(ctx context.Context) []*Account
|
GetAllAccounts(ctx context.Context) []*Account
|
||||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||||
DeleteAccount(ctx context.Context, account *Account) error
|
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||||
|
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountIDByUserID(peerKey string) (string, error)
|
GetAccountIDByUserID(userID string) (string, error)
|
||||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||||
|
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||||
|
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||||
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
|
DeleteAccount(ctx context.Context, account *Account) error
|
||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
|
||||||
SaveAccount(ctx context.Context, account *Account) error
|
|
||||||
SaveUsers(accountID string, users map[string]*User) error
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
|
|
||||||
|
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||||
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||||
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||||
|
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||||
|
|
||||||
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||||
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||||
|
|
||||||
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||||
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||||
|
|
||||||
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
|
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||||
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||||
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||||
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
|
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
|
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||||
|
|
||||||
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||||
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||||
|
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||||
|
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||||
|
|
||||||
|
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||||
|
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||||
|
|
||||||
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||||
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||||
|
|
||||||
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
|
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||||
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||||
|
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
SaveInstallationID(ctx context.Context, ID string) error
|
SaveInstallationID(ctx context.Context, ID string) error
|
||||||
|
|
||||||
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
||||||
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
||||||
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
||||||
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
||||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||||
AcquireGlobalLock(ctx context.Context) func()
|
AcquireGlobalLock(ctx context.Context) func()
|
||||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
|
||||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
|
||||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
|
||||||
// Close should close the store persisting all unsaved data.
|
// Close should close the store persisting all unsaved data.
|
||||||
Close(ctx context.Context) error
|
Close(ctx context.Context) error
|
||||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||||
// This is also a method of metrics.DataSource interface.
|
// This is also a method of metrics.DataSource interface.
|
||||||
GetStoreEngine() StoreEngine
|
GetStoreEngine() StoreEngine
|
||||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
|
||||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
|
||||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
|
||||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
|
||||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
|
||||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
|
||||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
|
||||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
|
||||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
|
||||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -94,6 +94,11 @@ func (u *User) HasAdminPower() bool {
|
|||||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
|
||||||
|
func (u *User) IsAdminOrServiceUser() bool {
|
||||||
|
return u.HasAdminPower() || u.IsServiceUser
|
||||||
|
}
|
||||||
|
|
||||||
// ToUserInfo converts a User object to a UserInfo object.
|
// ToUserInfo converts a User object to a UserInfo object.
|
||||||
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
||||||
autoGroups := u.AutoGroups
|
autoGroups := u.AutoGroups
|
||||||
@@ -357,39 +362,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
return newUser.ToUserInfo(idpUser, account.Settings)
|
return newUser.ToUserInfo(idpUser, account.Settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||||
|
return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
|
||||||
|
}
|
||||||
|
|
||||||
// GetUser looks up a user by provided authorization claims.
|
// GetUser looks up a user by provided authorization claims.
|
||||||
// It will also create an account if didn't exist for this user before.
|
// It will also create an account if didn't exist for this user before.
|
||||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get an account from store %v", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := account.Users[claims.UserId]
|
// this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC
|
|
||||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||||
|
|
||||||
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
|
err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if newLogin {
|
if newLogin {
|
||||||
meta := map[string]any{"timestamp": claims.LastLogin}
|
meta := map[string]any{"timestamp": claims.LastLogin}
|
||||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta)
|
am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
@@ -642,63 +643,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
|||||||
|
|
||||||
// GetPAT returns a specific PAT from a user
|
// GetPAT returns a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, ok := account.Users[targetUserID]
|
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser, ok := account.Users[initiatorUserID]
|
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
||||||
if !ok {
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
for _, pat := range targetUser.PATsG {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
|
if pat.ID == tokenID {
|
||||||
|
return pat.Copy(), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pat := targetUser.PATs[tokenID]
|
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||||
if pat == nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return pat, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPATs returns all PATs for a user
|
// GetAllPATs returns all PATs for a user
|
||||||
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, ok := account.Users[targetUserID]
|
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser, ok := account.Users[initiatorUserID]
|
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
var pats []*PersonalAccessToken
|
pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG))
|
||||||
for _, pat := range targetUser.PATs {
|
for _, pat := range targetUser.PATsG {
|
||||||
pats = append(pats, pat)
|
pats = append(pats, pat.Copy())
|
||||||
}
|
}
|
||||||
|
|
||||||
return pats, nil
|
return pats, nil
|
||||||
|
|||||||
@@ -199,7 +199,8 @@ func TestUser_GetPAT(t *testing.T) {
|
|||||||
defer store.Close(context.Background())
|
defer store.Close(context.Background())
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||||
account.Users[mockUserID] = &User{
|
account.Users[mockUserID] = &User{
|
||||||
Id: mockUserID,
|
Id: mockUserID,
|
||||||
|
AccountID: mockAccountID,
|
||||||
PATs: map[string]*PersonalAccessToken{
|
PATs: map[string]*PersonalAccessToken{
|
||||||
mockTokenID1: {
|
mockTokenID1: {
|
||||||
ID: mockTokenID1,
|
ID: mockTokenID1,
|
||||||
@@ -231,7 +232,8 @@ func TestUser_GetAllPATs(t *testing.T) {
|
|||||||
defer store.Close(context.Background())
|
defer store.Close(context.Background())
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||||
account.Users[mockUserID] = &User{
|
account.Users[mockUserID] = &User{
|
||||||
Id: mockUserID,
|
Id: mockUserID,
|
||||||
|
AccountID: mockAccountID,
|
||||||
PATs: map[string]*PersonalAccessToken{
|
PATs: map[string]*PersonalAccessToken{
|
||||||
mockTokenID1: {
|
mockTokenID1: {
|
||||||
ID: mockTokenID1,
|
ID: mockTokenID1,
|
||||||
@@ -796,7 +798,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
acc, err := am.Store.GetAccount(context.Background(), accID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
for _, id := range tc.expectedDeleted {
|
for _, id := range tc.expectedDeleted {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) {
|
|||||||
PeerID: "test",
|
PeerID: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ SUDO=""
|
|||||||
|
|
||||||
if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then
|
if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then
|
||||||
SUDO="sudo"
|
SUDO="sudo"
|
||||||
|
elif command -v doas > /dev/null && [ "$(id -u)" -ne 0 ]; then
|
||||||
|
SUDO="doas"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ -z ${NETBIRD_RELEASE+x} ]; then
|
if [ -z ${NETBIRD_RELEASE+x} ]; then
|
||||||
@@ -68,7 +70,7 @@ download_release_binary() {
|
|||||||
if [ -n "$GITHUB_TOKEN" ]; then
|
if [ -n "$GITHUB_TOKEN" ]; then
|
||||||
cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL"
|
cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL"
|
||||||
else
|
else
|
||||||
cd /tmp && curl -LO "$DOWNLOAD_URL"
|
cd /tmp && curl -LO "$DOWNLOAD_URL" || curl -LO --dns-servers 8.8.8.8 "$DOWNLOAD_URL"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
@@ -316,7 +318,7 @@ install_netbird() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
version_greater_equal() {
|
version_greater_equal() {
|
||||||
printf '%s\n%s\n' "$2" "$1" | sort -V -C
|
printf '%s\n%s\n' "$2" "$1" | sort -V -c
|
||||||
}
|
}
|
||||||
|
|
||||||
is_bin_package_manager() {
|
is_bin_package_manager() {
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ func startSignal() (*grpc.Server, net.Listener) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
srv, err := server.NewServer(otel.Meter(""))
|
srv, err := server.NewServer(context.Background(), otel.Meter(""))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ var (
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
srv, err := server.NewServer(metricsServer.Meter)
|
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating signal server: %v", err)
|
return fmt.Errorf("creating signal server: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,13 +47,13 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Signal server
|
// NewServer creates a new Signal server
|
||||||
func NewServer(meter metric.Meter) (*Server, error) {
|
func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
|
||||||
appMetrics, err := metrics.NewAppMetrics(meter)
|
appMetrics, err := metrics.NewAppMetrics(meter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating app metrics: %v", err)
|
return nil, fmt.Errorf("creating app metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatcher, err := dispatcher.NewDispatcher()
|
dispatcher, err := dispatcher.NewDispatcher(ctx, meter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating dispatcher: %v", err)
|
return nil, fmt.Errorf("creating dispatcher: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user