mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-27 04:36:37 +00:00
Compare commits
44 Commits
update-gom
...
sync-clien
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72513d7522 | ||
|
|
a1f1bf1f19 | ||
|
|
b5dec3df39 | ||
|
|
447cd287f5 | ||
|
|
5748bdd64e | ||
|
|
08f31fbcb3 | ||
|
|
932c02eaab | ||
|
|
abcbde26f9 | ||
|
|
90e3b8009f | ||
|
|
94d34dc0c5 | ||
|
|
44851e06fb | ||
|
|
3f4f825ec1 | ||
|
|
f538e6e9ae | ||
|
|
cb6b086164 | ||
|
|
71b6855e09 | ||
|
|
9bdc4908fb | ||
|
|
031ab11178 | ||
|
|
d2e48d4f5e | ||
|
|
27dd97c9c4 | ||
|
|
e87b4ace11 | ||
|
|
a232cf614c | ||
|
|
a293f760af | ||
|
|
10e9cf8c62 | ||
|
|
7193bd2da7 | ||
|
|
52948ccd61 | ||
|
|
4b77359042 | ||
|
|
387d43bcc1 | ||
|
|
e47d815dd2 | ||
|
|
cb83b7c0d3 | ||
|
|
ddcd182859 | ||
|
|
aca0398105 | ||
|
|
02200d790b | ||
|
|
f31bba87b4 | ||
|
|
7285fef0f0 | ||
|
|
20973063d8 | ||
|
|
ba2e9b6d88 | ||
|
|
131d7a3694 | ||
|
|
290fe2d8b9 | ||
|
|
7fb1a2fe31 | ||
|
|
20f5f00635 | ||
|
|
fc141cf3a3 | ||
|
|
d0c65fa08e | ||
|
|
f241bfa339 | ||
|
|
4b2cd97d5f |
11
.githooks/pre-push
Executable file
11
.githooks/pre-push
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "Running pre-push hook..."
|
||||||
|
if ! make lint; then
|
||||||
|
echo ""
|
||||||
|
echo "Hint: To push without verification, run:"
|
||||||
|
echo " git push --no-verify"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "All checks passed!"
|
||||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -15,13 +15,14 @@ jobs:
|
|||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
release: "14.2"
|
release: "14.2"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y curl pkgconf xorg
|
pkg install -y curl pkgconf xorg
|
||||||
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
|
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
curl -vLO "$GO_URL"
|
curl -vLO "$GO_URL"
|
||||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||||
|
|||||||
71
.github/workflows/golang-test-linux.yml
vendored
71
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
@@ -106,15 +106,15 @@ jobs:
|
|||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -151,15 +151,15 @@ jobs:
|
|||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
id: go-env
|
id: go-env
|
||||||
run: |
|
run: |
|
||||||
@@ -200,7 +200,7 @@ jobs:
|
|||||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||||
-e CONTAINER=${CONTAINER} \
|
-e CONTAINER=${CONTAINER} \
|
||||||
golang:1.23-alpine \
|
golang:1.24-alpine \
|
||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
@@ -220,15 +220,15 @@ jobs:
|
|||||||
raceFlag: "-race"
|
raceFlag: "-race"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
@@ -270,15 +270,15 @@ jobs:
|
|||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
@@ -321,15 +321,15 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -408,15 +408,16 @@ jobs:
|
|||||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||||
-p 9090:9090 \
|
-p 9090:9090 \
|
||||||
prom/prometheus
|
prom/prometheus
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: "1.23.x"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -497,15 +498,15 @@ jobs:
|
|||||||
-p 9090:9090 \
|
-p 9090:9090 \
|
||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -561,15 +562,15 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres']
|
store: [ 'sqlite', 'postgres']
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
|||||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
uses: android-actions/setup-android@v3
|
uses: android-actions/setup-android@v3
|
||||||
with:
|
with:
|
||||||
@@ -39,7 +39,7 @@ jobs:
|
|||||||
- name: Setup NDK
|
- name: Setup NDK
|
||||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build android netbird lib
|
- name: build android netbird lib
|
||||||
@@ -56,9 +56,9 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||||
- name: gomobile init
|
- name: gomobile init
|
||||||
run: gomobile init
|
run: gomobile init
|
||||||
- name: build iOS netbird lib
|
- name: build iOS netbird lib
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -20,7 +20,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-latest-m
|
||||||
env:
|
env:
|
||||||
flags: ""
|
flags: ""
|
||||||
steps:
|
steps:
|
||||||
@@ -40,7 +40,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -136,7 +136,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -200,7 +200,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
|
|||||||
@@ -67,10 +67,13 @@ jobs:
|
|||||||
- name: Install curl
|
- name: Install curl
|
||||||
run: sudo apt-get install -y curl
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
@@ -80,9 +83,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Setup MySQL privileges
|
- name: Setup MySQL privileges
|
||||||
if: matrix.store == 'mysql'
|
if: matrix.store == 'mysql'
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: Install golangci-lint
|
- name: Install golangci-lint
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||||
env:
|
env:
|
||||||
@@ -60,8 +60,8 @@ jobs:
|
|||||||
|
|
||||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
if [ ${SIZE} -gt 52428800 ]; then
|
if [ ${SIZE} -gt 57671680 ]; then
|
||||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -136,6 +136,14 @@ checked out and set up:
|
|||||||
go mod tidy
|
go mod tidy
|
||||||
```
|
```
|
||||||
|
|
||||||
|
6. Configure Git hooks for automatic linting:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make setup-hooks
|
||||||
|
```
|
||||||
|
|
||||||
|
This will configure Git to run linting automatically before each push, helping catch issues early.
|
||||||
|
|
||||||
### Dev Container Support
|
### Dev Container Support
|
||||||
|
|
||||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||||
|
|||||||
27
Makefile
Normal file
27
Makefile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
.PHONY: lint lint-all lint-install setup-hooks
|
||||||
|
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||||
|
|
||||||
|
# Install golangci-lint locally if needed
|
||||||
|
$(GOLANGCI_LINT):
|
||||||
|
@echo "Installing golangci-lint..."
|
||||||
|
@mkdir -p ./bin
|
||||||
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
|
# Lint only changed files (fast, for pre-push)
|
||||||
|
lint: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on changed files..."
|
||||||
|
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
|
||||||
|
|
||||||
|
# Lint entire codebase (slow, matches CI)
|
||||||
|
lint-all: $(GOLANGCI_LINT)
|
||||||
|
@echo "Running lint on all files..."
|
||||||
|
@$(GOLANGCI_LINT) run --timeout=12m
|
||||||
|
|
||||||
|
# Just install the linter
|
||||||
|
lint-install: $(GOLANGCI_LINT)
|
||||||
|
|
||||||
|
# Setup git hooks for all developers
|
||||||
|
setup-hooks:
|
||||||
|
@git config core.hooksPath .githooks
|
||||||
|
@chmod +x .githooks/pre-push
|
||||||
|
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||||
@@ -92,7 +92,7 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
exportEnvList(envList)
|
exportEnvList(envList)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
@@ -115,7 +115,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
c.ctxCancelLock.Unlock()
|
c.ctxCancelLock.Unlock()
|
||||||
|
|
||||||
auth := NewAuthWithConfig(ctx, cfg)
|
auth := NewAuthWithConfig(ctx, cfg)
|
||||||
err = auth.login(urlOpener)
|
err = auth.login(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ type ErrListener interface {
|
|||||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
Open(string)
|
Open(url string, userCode string)
|
||||||
OnLoginSuccess()
|
OnLoginSuccess()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Login try register the client on the server
|
// Login try register the client on the server
|
||||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
|
||||||
go func() {
|
go func() {
|
||||||
err := a.login(urlOpener)
|
err := a.login(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resultListener.OnError(err)
|
resultListener.OnError(err)
|
||||||
} else {
|
} else {
|
||||||
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener) error {
|
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||||
|
|||||||
@@ -4,14 +4,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
@@ -332,7 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
hint = profileState.Email
|
hint = profileState.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
if !noBrowser {
|
if !noBrowser {
|
||||||
if err := openBrowser(verificationURIComplete); err != nil {
|
if err := util.OpenBrowser(verificationURIComplete); err != nil {
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
|
||||||
func openBrowser(url string) error {
|
|
||||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
|
||||||
return exec.Command(browser, url).Start()
|
|
||||||
}
|
|
||||||
return open.Run(url)
|
|
||||||
}
|
|
||||||
|
|
||||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isUnixRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ var (
|
|||||||
identityFile string
|
identityFile string
|
||||||
skipCachedToken bool
|
skipCachedToken bool
|
||||||
requestPTY bool
|
requestPTY bool
|
||||||
|
sshNoBrowser bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -81,6 +82,7 @@ func init() {
|
|||||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||||
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
|
||||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||||
@@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string {
|
|||||||
return defaultValue
|
return defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
|
||||||
|
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
|
||||||
|
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
// resetSSHGlobals sets SSH globals to their default values
|
// resetSSHGlobals sets SSH globals to their default values
|
||||||
func resetSSHGlobals() {
|
func resetSSHGlobals() {
|
||||||
port = sshserver.DefaultSSHPort
|
port = sshserver.DefaultSSHPort
|
||||||
@@ -196,6 +213,7 @@ func resetSSHGlobals() {
|
|||||||
strictHostKeyChecking = true
|
strictHostKeyChecking = true
|
||||||
knownHostsFile = ""
|
knownHostsFile = ""
|
||||||
identityFile = ""
|
identityFile = ""
|
||||||
|
sshNoBrowser = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||||
@@ -370,6 +388,7 @@ type sshFlags struct {
|
|||||||
KnownHostsFile string
|
KnownHostsFile string
|
||||||
IdentityFile string
|
IdentityFile string
|
||||||
SkipCachedToken bool
|
SkipCachedToken bool
|
||||||
|
NoBrowser bool
|
||||||
ConfigPath string
|
ConfigPath string
|
||||||
LogLevel string
|
LogLevel string
|
||||||
LocalForwards []string
|
LocalForwards []string
|
||||||
@@ -381,6 +400,7 @@ type sshFlags struct {
|
|||||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||||
|
|
||||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||||
fs.SetOutput(nil)
|
fs.SetOutput(nil)
|
||||||
@@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
|||||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||||
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
|
||||||
|
|
||||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||||
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||||
@@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
|||||||
knownHostsFile = flags.KnownHostsFile
|
knownHostsFile = flags.KnownHostsFile
|
||||||
identityFile = flags.IdentityFile
|
identityFile = flags.IdentityFile
|
||||||
skipCachedToken = flags.SkipCachedToken
|
skipCachedToken = flags.SkipCachedToken
|
||||||
|
sshNoBrowser = flags.NoBrowser
|
||||||
|
|
||||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||||
configPath = flags.ConfigPath
|
configPath = flags.ConfigPath
|
||||||
@@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
|||||||
DaemonAddr: daemonAddr,
|
DaemonAddr: daemonAddr,
|
||||||
SkipCachedToken: skipCachedToken,
|
SkipCachedToken: skipCachedToken,
|
||||||
InsecureSkipVerify: !strictHostKeyChecking,
|
InsecureSkipVerify: !strictHostKeyChecking,
|
||||||
|
NoBrowser: sshNoBrowser,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -763,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("invalid port: %s", portStr)
|
return fmt.Errorf("invalid port: %s", portStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
|
||||||
|
// where command-line flags cannot be passed. Default is to open browser.
|
||||||
|
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||||
|
var browserOpener func(string) error
|
||||||
|
if !noBrowser {
|
||||||
|
browserOpener = util.OpenBrowser
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create SSH proxy: %w", err)
|
return fmt.Errorf("create SSH proxy: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
@@ -24,8 +26,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config)
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
|
tableRaw = "raw"
|
||||||
|
tableSecurity = "security"
|
||||||
|
|
||||||
chainNameNatPrerouting = "PREROUTING"
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
@@ -91,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
|||||||
var err error
|
var err error
|
||||||
r.filterTable, err = r.loadFilterTable()
|
r.filterTable, err = r.loadFilterTable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errFilterTableNotFound) {
|
log.Debugf("ip filter table not found: %v", err)
|
||||||
log.Warnf("table 'filter' not found for forward rules")
|
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("load filter table: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
@@ -175,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
return nil, fmt.Errorf("list tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
@@ -187,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
|||||||
return nil, errFilterTableNotFound
|
return nil, errFilterTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hookName(hook *nftables.ChainHook) string {
|
||||||
|
if hook == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
switch *hook {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
return chainNameForward
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
return chainNameInput
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("hook(%d)", *hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func familyName(family nftables.TableFamily) string {
|
||||||
|
switch family {
|
||||||
|
case nftables.TableFamilyIPv4:
|
||||||
|
return "ip"
|
||||||
|
case nftables.TableFamilyIPv6:
|
||||||
|
return "ip6"
|
||||||
|
case nftables.TableFamilyINet:
|
||||||
|
return "inet"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("family(%d)", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingFw,
|
Name: chainNameRoutingFw,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
})
|
})
|
||||||
|
|
||||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
|
||||||
|
|
||||||
prio := *nftables.ChainPriorityNATSource - 1
|
prio := *nftables.ChainPriorityNATSource - 1
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingNat,
|
Name: chainNameRoutingNat,
|
||||||
@@ -236,9 +261,12 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
|
||||||
return fmt.Errorf("add single nat rule: %v", err)
|
r.addPostroutingRules()
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("initialize tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.addMSSClampingRules(); err != nil {
|
if err := r.addMSSClampingRules(); err != nil {
|
||||||
@@ -250,11 +278,7 @@ func (r *router) createContainers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("initialize tables: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -695,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addPostroutingRules adds the masquerade rules
|
// addPostroutingRules adds the masquerade rules
|
||||||
func (r *router) addPostroutingRules() error {
|
func (r *router) addPostroutingRules() {
|
||||||
// First masquerade rule for traffic coming in from WireGuard interface
|
// First masquerade rule for traffic coming in from WireGuard interface
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
// Match on the first fwmark
|
// Match on the first fwmark
|
||||||
@@ -761,8 +785,6 @@ func (r *router) addPostroutingRules() error {
|
|||||||
Chain: r.chains[chainNameRoutingNat],
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
Exprs: exprs2,
|
Exprs: exprs2,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
@@ -839,7 +861,7 @@ func (r *router) addMSSClampingRules() error {
|
|||||||
Exprs: exprsOut,
|
Exprs: exprsOut,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return r.conn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
@@ -939,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||||
func (r *router) acceptForwardRules() error {
|
func (r *router) acceptForwardRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.acceptFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.acceptExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -953,11 +988,11 @@ func (r *router) acceptForwardRules() error {
|
|||||||
// Try iptables first and fallback to nftables if iptables is not available
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// filter table exists but iptables is not
|
// iptables is not available but the filter table exists
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
fw = "nftables"
|
fw = "nftables"
|
||||||
return r.acceptFilterRulesNftables()
|
return r.acceptFilterRulesNftables(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.acceptFilterRulesIptables(ipt)
|
return r.acceptFilterRulesIptables(ipt)
|
||||||
@@ -968,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables forward rule: %v", rule)
|
log.Debugf("added iptables forward rule: %v", rule)
|
||||||
}
|
}
|
||||||
@@ -976,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("added iptables input rule: %v", inputRule)
|
log.Debugf("added iptables input rule: %v", inputRule)
|
||||||
}
|
}
|
||||||
@@ -996,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
|
|||||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) acceptFilterRulesNftables() error {
|
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||||
|
// This is used when iptables is not available.
|
||||||
|
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
iifRule := &nftables.Rule{
|
forwardChain := &nftables.Chain{
|
||||||
Table: r.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: chainNameForward,
|
Name: chainNameForward,
|
||||||
Table: r.filterTable,
|
Table: table,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
Hooknum: nftables.ChainHookForward,
|
Hooknum: nftables.ChainHookForward,
|
||||||
Priority: nftables.ChainPriorityFilter,
|
Priority: nftables.ChainPriorityFilter,
|
||||||
},
|
}
|
||||||
|
r.insertForwardAcceptRules(forwardChain, intf)
|
||||||
|
|
||||||
|
inputChain := &nftables.Chain{
|
||||||
|
Name: chainNameInput,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookInput,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertInputAcceptRule(inputChain, intf)
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||||
|
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||||
|
func (r *router) acceptExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||||
|
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||||
|
|
||||||
|
switch *chain.Hooknum {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
r.insertForwardAcceptRules(chain, intf)
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
r.insertInputAcceptRule(chain, intf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush external chain rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||||
|
iifRule := &nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1030,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
Data: intf,
|
Data: intf,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
oifRule := &nftables.Rule{
|
oifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameForward,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(oifRule)
|
r.conn.InsertRule(oifRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||||
inputRule := &nftables.Rule{
|
inputRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: chain.Table,
|
||||||
Chain: &nftables.Chain{
|
Chain: chain,
|
||||||
Name: chainNameInput,
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookInput,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@@ -1067,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
|
|||||||
UserData: []byte(userDataAcceptInputRule),
|
UserData: []byte(userDataAcceptInputRule),
|
||||||
}
|
}
|
||||||
r.conn.InsertRule(inputRule)
|
r.conn.InsertRule(inputRule)
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRules() error {
|
func (r *router) removeAcceptFilterRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeFilterTableRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipt, err := iptables.New()
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||||
return r.removeAcceptFilterRulesNftables()
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.removeAcceptFilterRulesIptables(ipt)
|
return r.removeAcceptFilterRulesIptables(ipt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list chains: %v", err)
|
return fmt.Errorf("list chains: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
if chain.Table.Name != r.filterTable.Name {
|
if chain.Table.Name != table.Name {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1100,9 +1188,18 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get rules: %v", err)
|
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
@@ -1110,31 +1207,96 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
|||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
return fmt.Errorf("delete rule: %v", err)
|
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||||
|
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||||
|
// ensuring cleanup works even after a crash or if chains changed.
|
||||||
|
func (r *router) removeExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||||
|
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||||
|
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||||
|
func (r *router) findExternalChains() []*nftables.Chain {
|
||||||
|
var chains []*nftables.Chain
|
||||||
|
|
||||||
|
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||||
|
|
||||||
|
for _, family := range families {
|
||||||
|
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("list chains for family %d: %v", family, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range allChains {
|
||||||
|
if r.isExternalChain(chain) {
|
||||||
|
chains = append(chains, chain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return chains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||||
|
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip all iptables-managed tables in the ip family
|
||||||
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Type != nftables.ChainTypeFilter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesTable(name string) bool {
|
||||||
|
switch name {
|
||||||
|
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
inputRule := r.getAcceptInputRule()
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
@@ -1196,7 +1358,7 @@ func (r *router) refreshRulesMap() error {
|
|||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(" unable to list rules: %v", err)
|
return fmt.Errorf("list rules: %w", err)
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
|
|||||||
@@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
return t.AccessToken
|
return t.AccessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
|
||||||
|
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
|
||||||
|
}
|
||||||
|
|
||||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||||
//
|
//
|
||||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
|
||||||
|
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||||
@@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct {
|
|||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
var availableRedirectURL string
|
var availableRedirectURL string
|
||||||
|
|
||||||
// find the first available redirect URL
|
excludedRanges := getSystemExcludedPortRanges()
|
||||||
|
|
||||||
for _, redirectURL := range config.RedirectURLs {
|
for _, redirectURL := range config.RedirectURLs {
|
||||||
if !isRedirectURLPortUsed(redirectURL) {
|
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
|
||||||
availableRedirectURL = redirectURL
|
availableRedirectURL = redirectURL
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -102,10 +105,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
}
|
}
|
||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
switch p.providerConfig.LoginFlag {
|
||||||
|
case common.LoginFlagPromptLogin:
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
}
|
case common.LoginFlagMaxAge0:
|
||||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
|
||||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -282,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string {
|
|||||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
|
||||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
|
||||||
parsedURL, err := url.Parse(redirectURL)
|
parsedURL, err := url.Parse(redirectURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse redirect URL: %v", err)
|
log.Errorf("failed to parse redirect URL: %v", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
port := parsedURL.Port()
|
||||||
|
|
||||||
|
if isPortInExcludedRange(port, excludedRanges) {
|
||||||
|
log.Warnf("port %s is in Windows excluded port range, skipping", port)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf(":%s", port)
|
||||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
@@ -304,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// excludedPortRange represents a range of excluded ports.
|
||||||
|
type excludedPortRange struct {
|
||||||
|
start int
|
||||||
|
end int
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
|
||||||
|
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
|
||||||
|
if len(excludedRanges) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
portNum, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("invalid port number %s: %v", port, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range excludedRanges {
|
||||||
|
if portNum >= r.start && portNum <= r.end {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
8
client/internal/auth/pkce_flow_other.go
Normal file
8
client/internal/auth/pkce_flow_other.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
|
||||||
|
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,8 +2,11 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
loginFlag mgm.LoginFlag
|
loginFlag mgm.LoginFlag
|
||||||
disablePromptLogin bool
|
disablePromptLogin bool
|
||||||
expect string
|
expectContains []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Prompt login",
|
name: "Prompt login",
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
expect: promptLogin,
|
expectContains: []string{promptLogin},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Max age 0 login",
|
name: "Max age 0",
|
||||||
loginFlag: mgm.LoginFlagMaxAge0,
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
expect: maxAge0,
|
expectContains: []string{maxAge0},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disable prompt login",
|
name: "Disable prompt login",
|
||||||
loginFlag: mgm.LoginFlagPrompt,
|
loginFlag: mgm.LoginFlagPromptLogin,
|
||||||
disablePromptLogin: true,
|
disablePromptLogin: true,
|
||||||
|
expectContains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "None flag should not add parameters",
|
||||||
|
loginFlag: mgm.LoginFlagNone,
|
||||||
|
expectContains: []string{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
UseIDToken: true,
|
UseIDToken: true,
|
||||||
LoginFlag: tc.loginFlag,
|
LoginFlag: tc.loginFlag,
|
||||||
|
DisablePromptLogin: tc.disablePromptLogin,
|
||||||
}
|
}
|
||||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
t.Fatalf("Failed to request auth info: %v", err)
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tc.disablePromptLogin {
|
for _, expected := range tc.expectContains {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
require.Contains(t, authInfo.VerificationURIComplete, expected)
|
||||||
} else {
|
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
|
||||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsPortInExcludedRange(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
port string
|
||||||
|
excludedRanges []excludedPortRange
|
||||||
|
expectedBlocked bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Port in excluded range",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port at start of range",
|
||||||
|
port: "8000",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port at end of range",
|
||||||
|
port: "8100",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port before range",
|
||||||
|
port: "7999",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port after range",
|
||||||
|
port: "8101",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty excluded ranges",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: []excludedPortRange{},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil excluded ranges",
|
||||||
|
port: "8080",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple ranges - port in second range",
|
||||||
|
port: "9050",
|
||||||
|
excludedRanges: []excludedPortRange{
|
||||||
|
{start: 8000, end: 8100},
|
||||||
|
{start: 9000, end: 9100},
|
||||||
|
},
|
||||||
|
expectedBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple ranges - port not in any range",
|
||||||
|
port: "8500",
|
||||||
|
excludedRanges: []excludedPortRange{
|
||||||
|
{start: 8000, end: 8100},
|
||||||
|
{start: 9000, end: 9100},
|
||||||
|
},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid port string",
|
||||||
|
port: "invalid",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty port string",
|
||||||
|
port: "",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedBlocked: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
|
||||||
|
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRedirectURLPortUsed(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = listener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
usedPort := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
redirectURL string
|
||||||
|
excludedRanges []excludedPortRange
|
||||||
|
expectedUsed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Port in excluded range",
|
||||||
|
redirectURL: "http://127.0.0.1:8080/",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port actually in use",
|
||||||
|
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port not in use and not excluded",
|
||||||
|
redirectURL: "http://127.0.0.1:65432/",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid URL without port",
|
||||||
|
redirectURL: "not-a-valid-url",
|
||||||
|
excludedRanges: nil,
|
||||||
|
expectedUsed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port excluded even if not in use",
|
||||||
|
redirectURL: "http://127.0.0.1:8050/",
|
||||||
|
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||||
|
expectedUsed: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
|
||||||
|
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
86
client/internal/auth/pkce_flow_windows.go
Normal file
86
client/internal/auth/pkce_flow_windows.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
|
||||||
|
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||||
|
ranges, err := getExcludedPortRangesFromNetsh()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get Windows excluded port ranges: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ranges
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
|
||||||
|
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
|
||||||
|
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("netsh command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseExcludedPortRanges(string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
|
||||||
|
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
|
||||||
|
var ranges []excludedPortRange
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||||
|
|
||||||
|
foundHeader := false
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
|
||||||
|
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
|
||||||
|
foundHeader = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundHeader {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(line, "----------") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
startPort, err := strconv.Atoi(fields[0])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
endPort, err := strconv.Atoi(fields[1])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ranges, nil
|
||||||
|
}
|
||||||
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseExcludedPortRanges(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
netshOutput string
|
||||||
|
expectedRanges []excludedPortRange
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid netsh output with multiple ranges",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Dynamic Port Range
|
||||||
|
---------------------------------
|
||||||
|
Start Port : 49152
|
||||||
|
Number of Ports : 16384
|
||||||
|
|
||||||
|
Protocol tcp Excluded Port Ranges
|
||||||
|
---------------------------------
|
||||||
|
Start Port End Port
|
||||||
|
---------- --------
|
||||||
|
5357 5357 *
|
||||||
|
50000 50059 *
|
||||||
|
`,
|
||||||
|
expectedRanges: []excludedPortRange{
|
||||||
|
{start: 5357, end: 5357},
|
||||||
|
{start: 50000, end: 50059},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty output",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Dynamic Port Range
|
||||||
|
---------------------------------
|
||||||
|
Start Port : 49152
|
||||||
|
Number of Ports : 16384
|
||||||
|
`,
|
||||||
|
expectedRanges: nil,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single range",
|
||||||
|
netshOutput: `
|
||||||
|
Protocol tcp Excluded Port Ranges
|
||||||
|
---------------------------------
|
||||||
|
Start Port End Port
|
||||||
|
---------- --------
|
||||||
|
8080 8090
|
||||||
|
`,
|
||||||
|
expectedRanges: []excludedPortRange{
|
||||||
|
{start: 8080, end: 8090},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ranges, err := parseExcludedPortRanges(tt.netshOutput)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedRanges, ranges)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||||
|
ranges := getSystemExcludedPortRanges()
|
||||||
|
t.Logf("Found %d excluded port ranges on this system", len(ranges))
|
||||||
|
|
||||||
|
listener1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = listener1.Close()
|
||||||
|
}()
|
||||||
|
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
availablePort := 65432
|
||||||
|
|
||||||
|
config := internal.PKCEAuthProviderConfig{
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Audience: "test-audience",
|
||||||
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
Scope: "openid email profile",
|
||||||
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
|
RedirectURLs: []string{
|
||||||
|
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
|
||||||
|
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
|
||||||
|
},
|
||||||
|
UseIDToken: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewPKCEAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, flow)
|
||||||
|
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
|
||||||
|
"Should skip port in use and select available port")
|
||||||
|
}
|
||||||
@@ -273,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
|
c.engine = engine
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
@@ -293,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
engine := c.engine
|
|
||||||
c.engine = nil
|
c.engine = nil
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if engine != nil && engine.wgInterface != nil {
|
// todo: consider to remove this condition. Is not thread safe.
|
||||||
|
// We should always call Stop(), but we need to verify that it is idempotent
|
||||||
|
if engine.wgInterface != nil {
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||||
|
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
|||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, zone := range config.CustomZones {
|
for _, zone := range config.CustomZones {
|
||||||
|
if zone.SkipPTRProcess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
if record.Type != int(dns.TypeA) {
|
if record.Type != int(dns.TypeA) {
|
||||||
continue
|
continue
|
||||||
@@ -108,6 +111,7 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
|||||||
reverseZone := nbdns.CustomZone{
|
reverseZone := nbdns.CustomZone{
|
||||||
Domain: zoneName,
|
Domain: zoneName,
|
||||||
Records: records,
|
Records: records,
|
||||||
|
SearchDomainDisabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||||
|
|||||||
@@ -11,11 +11,6 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4ReverseZone = ".in-addr.arpa."
|
|
||||||
ipv6ReverseZone = ".ip6.arpa."
|
|
||||||
)
|
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
@@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: customZone.SearchDomainDisabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
|||||||
timeoutMsg += " " + peerInfo
|
timeoutMsg += " " + peerInfo
|
||||||
}
|
}
|
||||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||||
logger.Warnf(timeoutMsg)
|
logger.Warn(timeoutMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||||
|
|||||||
@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||||
|
for i, ip := range ips {
|
||||||
|
ips[i] = ip.Unmap()
|
||||||
|
}
|
||||||
|
|
||||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
f.cache.set(domain, question.Qtype, ips)
|
f.cache.set(domain, question.Qtype, ips)
|
||||||
|
|||||||
@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
if e.connMgr != nil {
|
if e.connMgr != nil {
|
||||||
e.connMgr.Close()
|
e.connMgr.Close()
|
||||||
@@ -298,9 +297,6 @@ func (e *Engine) Stop() error {
|
|||||||
|
|
||||||
e.cleanupSSHConfig()
|
e.cleanupSSHConfig()
|
||||||
|
|
||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
|
||||||
e.stopDNSServer()
|
|
||||||
|
|
||||||
if e.ingressGatewayMgr != nil {
|
if e.ingressGatewayMgr != nil {
|
||||||
if err := e.ingressGatewayMgr.Close(); err != nil {
|
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||||
log.Warnf("failed to cleanup forward rules: %v", err)
|
log.Warnf("failed to cleanup forward rules: %v", err)
|
||||||
@@ -308,24 +304,29 @@ func (e *Engine) Stop() error {
|
|||||||
e.ingressGatewayMgr = nil
|
e.ingressGatewayMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
e.stopDNSForwarder()
|
|
||||||
|
|
||||||
if e.routeManager != nil {
|
|
||||||
e.routeManager.Stop(e.stateManager)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
if e.srWatcher != nil {
|
||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info("cleaning up status recorder states")
|
||||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||||
|
|
||||||
if err := e.removeAllPeers(); err != nil {
|
if err := e.removeAllPeers(); err != nil {
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
log.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.routeManager != nil {
|
||||||
|
e.routeManager.Stop(e.stateManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.stopDNSForwarder()
|
||||||
|
|
||||||
|
// stop/restore DNS after peers are closed but before interface goes down
|
||||||
|
// so dbus and friends don't complain because of a missing interface
|
||||||
|
e.stopDNSServer()
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
@@ -337,16 +338,18 @@ func (e *Engine) Stop() error {
|
|||||||
e.flowManager.Close()
|
e.flowManager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
defer cancel()
|
defer stateCancel()
|
||||||
|
|
||||||
if err := e.stateManager.Stop(ctx); err != nil {
|
if err := e.stateManager.Stop(stateCtx); err != nil {
|
||||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
log.Errorf("failed to stop state manager: %v", err)
|
||||||
}
|
}
|
||||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
timeout := e.calculateShutdownTimeout()
|
timeout := e.calculateShutdownTimeout()
|
||||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
@@ -432,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
err := e.rpManager.Run()
|
if err := e.rpManager.Run(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -485,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := e.createFirewall(); err != nil {
|
if err := e.createFirewall(); err != nil {
|
||||||
|
e.close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -750,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
wCfg := update.GetNetbirdConfig()
|
wCfg := update.GetNetbirdConfig()
|
||||||
err := e.updateTURNs(wCfg.GetTurns())
|
err := e.updateTURNs(wCfg.GetTurns())
|
||||||
@@ -789,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nm := update.GetNetworkMap()
|
nm := update.GetNetworkMap()
|
||||||
if nm == nil {
|
if nm == nil || update.SkipNetworkMapUpdate {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -955,7 +963,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
@@ -1208,6 +1216,8 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
|||||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||||
dnsZone := nbdns.CustomZone{
|
dnsZone := nbdns.CustomZone{
|
||||||
Domain: zone.GetDomain(),
|
Domain: zone.GetDomain(),
|
||||||
|
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
|
||||||
|
SkipPTRProcess: zone.GetSkipPTRProcess(),
|
||||||
}
|
}
|
||||||
for _, record := range zone.Records {
|
for _, record := range zone.Records {
|
||||||
dnsRecord := nbdns.SimpleRecord{
|
dnsRecord := nbdns.SimpleRecord{
|
||||||
@@ -1367,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
|
if e.ctx.Err() != nil {
|
||||||
|
return e.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
|
|||||||
79
client/internal/engine_sync_test.go
Normal file
79
client/internal/engine_sync_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
|
||||||
|
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
|
||||||
|
WgIfaceName: "utun199",
|
||||||
|
WgAddr: "100.70.0.1/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
// Precondition
|
||||||
|
if engine.networkSerial != 0 {
|
||||||
|
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &mgmtProto.SyncResponse{
|
||||||
|
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
|
||||||
|
SkipNetworkMapUpdate: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := engine.handleSync(resp); err != nil {
|
||||||
|
t.Fatalf("handleSync returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if engine.networkSerial != 0 {
|
||||||
|
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensures handleSync exits early when NetworkMap is nil
|
||||||
|
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
|
||||||
|
WgIfaceName: "utun198",
|
||||||
|
WgAddr: "100.70.0.2/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33101,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
|
||||||
|
|
||||||
|
if err := engine.handleSync(resp); err != nil {
|
||||||
|
t.Fatalf("handleSync returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -30,11 +30,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -54,7 +55,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -631,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
// feed updates to Engine via mocked Management client
|
// feed updates to Engine via mocked Management client
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
defer close(updates)
|
defer close(updates)
|
||||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
for msg := range updates {
|
for msg := range updates {
|
||||||
err := msgHandler(msg)
|
err := msgHandler(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1628,14 +1628,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
218
client/internal/sleep/detector_darwin.go
Normal file
218
client/internal/sleep/detector_darwin.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package sleep
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||||
|
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||||
|
#include <IOKit/IOMessage.h>
|
||||||
|
#include <CoreFoundation/CoreFoundation.h>
|
||||||
|
|
||||||
|
extern void sleepCallbackBridge();
|
||||||
|
extern void poweredOnCallbackBridge();
|
||||||
|
extern void suspendedCallbackBridge();
|
||||||
|
extern void resumedCallbackBridge();
|
||||||
|
|
||||||
|
|
||||||
|
// C global variables for IOKit state
|
||||||
|
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||||
|
static io_object_t g_notifierObject = 0;
|
||||||
|
static io_object_t g_generalInterestNotifier = 0;
|
||||||
|
static io_connect_t g_rootPort = 0;
|
||||||
|
static CFRunLoopRef g_runLoop = NULL;
|
||||||
|
|
||||||
|
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||||
|
switch (messageType) {
|
||||||
|
case kIOMessageSystemWillSleep:
|
||||||
|
sleepCallbackBridge();
|
||||||
|
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||||
|
break;
|
||||||
|
case kIOMessageSystemHasPoweredOn:
|
||||||
|
poweredOnCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsSuspended:
|
||||||
|
suspendedCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsResumed:
|
||||||
|
resumedCallbackBridge();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void registerNotifications() {
|
||||||
|
g_rootPort = IORegisterForSystemPower(
|
||||||
|
NULL,
|
||||||
|
&g_notifyPortRef,
|
||||||
|
(IOServiceInterestCallback)sleepCallback,
|
||||||
|
&g_notifierObject
|
||||||
|
);
|
||||||
|
|
||||||
|
if (g_rootPort == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
g_runLoop = CFRunLoopGetCurrent();
|
||||||
|
CFRunLoopRun();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unregisterNotifications() {
|
||||||
|
CFRunLoopRemoveSource(g_runLoop,
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
IODeregisterForSystemPower(&g_notifierObject);
|
||||||
|
IOServiceClose(g_rootPort);
|
||||||
|
IONotificationPortDestroy(g_notifyPortRef);
|
||||||
|
CFRunLoopStop(g_runLoop);
|
||||||
|
|
||||||
|
g_notifyPortRef = NULL;
|
||||||
|
g_notifierObject = 0;
|
||||||
|
g_rootPort = 0;
|
||||||
|
g_runLoop = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serviceRegistry = make(map[*Detector]struct{})
|
||||||
|
serviceRegistryMu sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
//export sleepCallbackBridge
|
||||||
|
func sleepCallbackBridge() {
|
||||||
|
log.Info("sleepCallbackBridge event triggered")
|
||||||
|
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeSleep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//export resumedCallbackBridge
|
||||||
|
func resumedCallbackBridge() {
|
||||||
|
log.Info("resumedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export suspendedCallbackBridge
|
||||||
|
func suspendedCallbackBridge() {
|
||||||
|
log.Info("suspendedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export poweredOnCallbackBridge
|
||||||
|
func poweredOnCallbackBridge() {
|
||||||
|
log.Info("poweredOnCallbackBridge event triggered")
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeWakeUp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Detector struct {
|
||||||
|
callback func(event EventType)
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDetector() (*Detector, error) {
|
||||||
|
return &Detector{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Register(callback func(event EventType)) error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := serviceRegistry[d]; exists {
|
||||||
|
return fmt.Errorf("detector service already registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
d.callback = callback
|
||||||
|
|
||||||
|
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
|
|
||||||
|
// CFRunLoop must run on a single fixed OS thread
|
||||||
|
go func() {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
C.registerNotifications()
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Info("sleep detection service started on macOS")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||||
|
// and the runloop is stopped and cleaned up.
|
||||||
|
func (d *Detector) Deregister() error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
_, exists := serviceRegistry[d]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancel and remove this detector
|
||||||
|
d.cancel()
|
||||||
|
delete(serviceRegistry, d)
|
||||||
|
|
||||||
|
// If other Detectors still exist, leave IOKit running
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("sleep detection service stopping (deregister)")
|
||||||
|
|
||||||
|
// Deregister IOKit notifications, stop runloop, and free resources
|
||||||
|
C.unregisterNotifications()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) triggerCallback(event EventType) {
|
||||||
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
|
timeout := time.NewTimer(500 * time.Millisecond)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
cb := d.callback
|
||||||
|
go func(callback func(event EventType)) {
|
||||||
|
log.Info("sleep detection event fired")
|
||||||
|
callback(event)
|
||||||
|
close(doneChan)
|
||||||
|
}(cb)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-doneChan:
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
case <-timeout.C:
|
||||||
|
log.Warnf("sleep callback timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
9
client/internal/sleep/detector_notsupported.go
Normal file
9
client/internal/sleep/detector_notsupported.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !darwin || ios
|
||||||
|
|
||||||
|
package sleep
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func NewDetector() (detector, error) {
|
||||||
|
return nil, fmt.Errorf("sleep not supported on this platform")
|
||||||
|
}
|
||||||
37
client/internal/sleep/service.go
Normal file
37
client/internal/sleep/service.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package sleep
|
||||||
|
|
||||||
|
var (
|
||||||
|
EventTypeUnknown EventType = 0
|
||||||
|
EventTypeSleep EventType = 1
|
||||||
|
EventTypeWakeUp EventType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type EventType int
|
||||||
|
|
||||||
|
type detector interface {
|
||||||
|
Register(callback func(eventType EventType)) error
|
||||||
|
Deregister() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
detector detector
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() (*Service, error) {
|
||||||
|
d, err := NewDetector()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{
|
||||||
|
detector: d,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Register(callback func(eventType EventType)) error {
|
||||||
|
return s.detector.Register(callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Deregister() error {
|
||||||
|
return s.detector.Deregister()
|
||||||
|
}
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(fd int32, interfaceName string) error {
|
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||||
|
exportEnvList(envList)
|
||||||
log.Infof("Starting NetBird client")
|
log.Infof("Starting NetBird client")
|
||||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
@@ -228,7 +232,7 @@ func (c *Client) LoginForMobile() string {
|
|||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
|
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
|
|||||||
}
|
}
|
||||||
return netIDs
|
return netIDs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func exportEnvList(list *EnvList) {
|
||||||
|
if list == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for k, v := range list.AllItems() {
|
||||||
|
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
|
||||||
|
log.Debugf("Setting env variable %s: %s", k, v)
|
||||||
|
|
||||||
|
if err := os.Setenv(k, v); err != nil {
|
||||||
|
log.Errorf("could not set env variable %s: %v", k, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Env variable %s was set successfully", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
34
client/ios/NetBirdSDK/env_list.go
Normal file
34
client/ios/NetBirdSDK/env_list.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package NetBirdSDK
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
|
||||||
|
// EnvList is an exported struct to be bound by gomobile
|
||||||
|
type EnvList struct {
|
||||||
|
data map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEnvList creates a new EnvList
|
||||||
|
func NewEnvList() *EnvList {
|
||||||
|
return &EnvList{data: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put adds a key-value pair
|
||||||
|
func (el *EnvList) Put(key, value string) {
|
||||||
|
el.data[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value by key
|
||||||
|
func (el *EnvList) Get(key string) string {
|
||||||
|
return el.data[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *EnvList) AllItems() map[string]string {
|
||||||
|
return el.data
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
|
||||||
|
func GetEnvKeyNBForceRelay() string {
|
||||||
|
return peer.EnvKeyNBForceRelay
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import _ "golang.org/x/mobile/bind"
|
import _ "golang.org/x/mobile/bind"
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package NetBirdSDK
|
package NetBirdSDK
|
||||||
|
|
||||||
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
|
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@ service DaemonService {
|
|||||||
// Status of the service.
|
// Status of the service.
|
||||||
rpc Status(StatusRequest) returns (StatusResponse) {}
|
rpc Status(StatusRequest) returns (StatusResponse) {}
|
||||||
|
|
||||||
// Down engine work in the daemon.
|
// Down stops engine work in the daemon.
|
||||||
rpc Down(DownRequest) returns (DownResponse) {}
|
rpc Down(DownRequest) returns (DownResponse) {}
|
||||||
|
|
||||||
// GetConfig of the daemon.
|
// GetConfig of the daemon.
|
||||||
@@ -93,9 +93,26 @@ service DaemonService {
|
|||||||
|
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||||
|
|
||||||
|
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
message OSLifecycleRequest {
|
||||||
|
// avoid collision with loglevel enum
|
||||||
|
enum CycleType {
|
||||||
|
UNKNOWN = 0;
|
||||||
|
SLEEP = 1;
|
||||||
|
WAKEUP = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
CycleType type = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message OSLifecycleResponse {}
|
||||||
|
|
||||||
|
|
||||||
message LoginRequest {
|
message LoginRequest {
|
||||||
// setupKey netbird setup key.
|
// setupKey netbird setup key.
|
||||||
string setupKey = 1;
|
string setupKey = 1;
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ type DaemonServiceClient interface {
|
|||||||
Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error)
|
Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error)
|
||||||
// Status of the service.
|
// Status of the service.
|
||||||
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
|
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
|
||||||
// Down engine work in the daemon.
|
// Down stops engine work in the daemon.
|
||||||
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
|
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
|
||||||
// GetConfig of the daemon.
|
// GetConfig of the daemon.
|
||||||
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
|
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
|
||||||
@@ -70,6 +70,7 @@ type DaemonServiceClient interface {
|
|||||||
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||||
|
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type daemonServiceClient struct {
|
type daemonServiceClient struct {
|
||||||
@@ -382,6 +383,15 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
|
||||||
|
out := new(OSLifecycleResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonServiceServer is the server API for DaemonService service.
|
// DaemonServiceServer is the server API for DaemonService service.
|
||||||
// All implementations must embed UnimplementedDaemonServiceServer
|
// All implementations must embed UnimplementedDaemonServiceServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
@@ -395,7 +405,7 @@ type DaemonServiceServer interface {
|
|||||||
Up(context.Context, *UpRequest) (*UpResponse, error)
|
Up(context.Context, *UpRequest) (*UpResponse, error)
|
||||||
// Status of the service.
|
// Status of the service.
|
||||||
Status(context.Context, *StatusRequest) (*StatusResponse, error)
|
Status(context.Context, *StatusRequest) (*StatusResponse, error)
|
||||||
// Down engine work in the daemon.
|
// Down stops engine work in the daemon.
|
||||||
Down(context.Context, *DownRequest) (*DownResponse, error)
|
Down(context.Context, *DownRequest) (*DownResponse, error)
|
||||||
// GetConfig of the daemon.
|
// GetConfig of the daemon.
|
||||||
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
|
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
|
||||||
@@ -438,6 +448,7 @@ type DaemonServiceServer interface {
|
|||||||
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||||
|
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
||||||
mustEmbedUnimplementedDaemonServiceServer()
|
mustEmbedUnimplementedDaemonServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -538,6 +549,9 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
|
|||||||
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||||
|
|
||||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
@@ -1112,6 +1126,24 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(OSLifecycleRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/NotifyOSLifecycle",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@@ -1239,6 +1271,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "WaitJWTToken",
|
MethodName: "WaitJWTToken",
|
||||||
Handler: _DaemonService_WaitJWTToken_Handler,
|
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "NotifyOSLifecycle",
|
||||||
|
Handler: _DaemonService_NotifyOSLifecycle_Handler,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{
|
Streams: []grpc.StreamDesc{
|
||||||
{
|
{
|
||||||
|
|||||||
77
client/server/lifecycle.go
Normal file
77
client/server/lifecycle.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||||
|
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||||
|
switch req.GetType() {
|
||||||
|
case proto.OSLifecycleRequest_WAKEUP:
|
||||||
|
return s.handleWakeUp(callerCtx)
|
||||||
|
case proto.OSLifecycleRequest_SLEEP:
|
||||||
|
return s.handleSleep(callerCtx)
|
||||||
|
default:
|
||||||
|
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||||
|
}
|
||||||
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
|
||||||
|
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
|
||||||
|
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
||||||
|
if !s.sleepTriggeredDown.Load() {
|
||||||
|
log.Info("skipping up because wasn't sleep down")
|
||||||
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||||
|
s.sleepTriggeredDown.Store(false)
|
||||||
|
|
||||||
|
log.Info("running up after wake up")
|
||||||
|
_, err := s.Up(callerCtx, &proto.UpRequest{})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("running up failed: %v", err)
|
||||||
|
return &proto.OSLifecycleResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("running up command executed successfully")
|
||||||
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
|
||||||
|
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
||||||
|
s.mutex.Lock()
|
||||||
|
|
||||||
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
status, err := state.Status()
|
||||||
|
if err != nil {
|
||||||
|
s.mutex.Unlock()
|
||||||
|
return &proto.OSLifecycleResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||||
|
log.Infof("skipping setting the agent down because status is %s", status)
|
||||||
|
s.mutex.Unlock()
|
||||||
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
|
}
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
log.Info("running down after system started sleeping")
|
||||||
|
|
||||||
|
_, err = s.Down(callerCtx, &proto.DownRequest{})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("running down failed: %v", err)
|
||||||
|
return &proto.OSLifecycleResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sleepTriggeredDown.Store(true)
|
||||||
|
|
||||||
|
log.Info("running down executed successfully")
|
||||||
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
|
}
|
||||||
219
client/server/lifecycle_test.go
Normal file
219
client/server/lifecycle_test.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestServer() *Server {
|
||||||
|
ctx := internal.CtxInitState(context.Background())
|
||||||
|
return &Server{
|
||||||
|
rootCtx: ctx,
|
||||||
|
statusRecorder: peer.NewRecorder(""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
|
||||||
|
// sleepTriggeredDown is false by default
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load())
|
||||||
|
|
||||||
|
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
|
||||||
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
state.Set(internal.StatusIdle)
|
||||||
|
|
||||||
|
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_SLEEP,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
|
||||||
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
state.Set(internal.StatusNeedsLogin)
|
||||||
|
|
||||||
|
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_SLEEP,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
|
||||||
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
state.Set(internal.StatusConnecting)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
s.actCancel = cancel
|
||||||
|
|
||||||
|
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_SLEEP,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
||||||
|
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
|
||||||
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
state.Set(internal.StatusConnected)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
s.actCancel = cancel
|
||||||
|
|
||||||
|
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_SLEEP,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
||||||
|
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
|
||||||
|
// Manually set the flag to simulate prior sleep down
|
||||||
|
s.sleepTriggeredDown.Store(true)
|
||||||
|
|
||||||
|
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
|
||||||
|
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
|
||||||
|
// First wakeup without prior sleep - should be no-op
|
||||||
|
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load())
|
||||||
|
|
||||||
|
// Simulate prior sleep
|
||||||
|
s.sleepTriggeredDown.Store(true)
|
||||||
|
|
||||||
|
// First wakeup after sleep - should reset flag
|
||||||
|
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||||
|
})
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load())
|
||||||
|
|
||||||
|
// Second wakeup - should be no-op
|
||||||
|
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||||
|
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
|
||||||
|
resp, err := s.handleWakeUp(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
s.sleepTriggeredDown.Store(true)
|
||||||
|
|
||||||
|
// Even if Up fails, flag should be reset
|
||||||
|
_, _ = s.handleWakeUp(context.Background())
|
||||||
|
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
status internal.StatusType
|
||||||
|
}{
|
||||||
|
{"Idle", internal.StatusIdle},
|
||||||
|
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||||
|
{"LoginFailed", internal.StatusLoginFailed},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
state.Set(tt.status)
|
||||||
|
|
||||||
|
resp, err := s.handleSleep(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.False(t, s.sleepTriggeredDown.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
status internal.StatusType
|
||||||
|
}{
|
||||||
|
{"Connecting", internal.StatusConnecting},
|
||||||
|
{"Connected", internal.StatusConnected},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s := newTestServer()
|
||||||
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
state.Set(tt.status)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
s.actCancel = cancel
|
||||||
|
|
||||||
|
resp, err := s.handleSleep(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
assert.True(t, s.sleepTriggeredDown.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -85,6 +85,9 @@ type Server struct {
|
|||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
|
|
||||||
|
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
|
||||||
|
sleepTriggeredDown atomic.Bool
|
||||||
|
|
||||||
jwtCache *jwtCache
|
jwtCache *jwtCache
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -504,7 +507,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
if msg.Hint != nil {
|
if msg.Hint != nil {
|
||||||
hint = *msg.Hint
|
hint = *msg.Hint
|
||||||
}
|
}
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.Set(internal.StatusLoginFailed)
|
state.Set(internal.StatusLoginFailed)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -819,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
|||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
if err := s.cleanupConnection(); err != nil {
|
if err := s.cleanupConnection(); err != nil {
|
||||||
|
// todo review to update the status in case any type of error
|
||||||
log.Errorf("failed to shut down properly: %v", err)
|
log.Errorf("failed to shut down properly: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -911,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
||||||
|
// todo review to update the status in case any type of error
|
||||||
log.Errorf("failed to cleanup connection: %v", err)
|
log.Errorf("failed to cleanup connection: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1235,7 +1240,7 @@ func (s *Server) RequestJWTAuth(
|
|||||||
}
|
}
|
||||||
|
|
||||||
isDesktop := isUnixRunningDesktop()
|
isDesktop := isUnixRunningDesktop()
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,11 +17,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -35,7 +36,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -316,14 +316,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
|
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
|
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -278,6 +279,7 @@ type DialOptions struct {
|
|||||||
DaemonAddr string
|
DaemonAddr string
|
||||||
SkipCachedToken bool
|
SkipCachedToken bool
|
||||||
InsecureSkipVerify bool
|
InsecureSkipVerify bool
|
||||||
|
NoBrowser bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial connects to the given ssh server with specified options
|
// Dial connects to the given ssh server with specified options
|
||||||
@@ -307,7 +309,7 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er
|
|||||||
config.Auth = append(config.Auth, authMethod)
|
config.Auth = append(config.Auth, authMethod)
|
||||||
}
|
}
|
||||||
|
|
||||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dialSSH establishes an SSH connection without JWT authentication
|
// dialSSH establishes an SSH connection without JWT authentication
|
||||||
@@ -333,7 +335,7 @@ func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
||||||
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
|
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) {
|
||||||
host, portStr, err := net.SplitHostPort(addr)
|
host, portStr, err := net.SplitHostPort(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
||||||
@@ -359,7 +361,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
|
|||||||
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
|
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request JWT token: %w", err)
|
return nil, fmt.Errorf("request JWT token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -369,7 +371,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// requestJWTToken requests a JWT token from the NetBird daemon
|
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) {
|
||||||
hint := profilemanager.GetLoginHint()
|
hint := profilemanager.GetLoginHint()
|
||||||
|
|
||||||
conn, err := connectToDaemon(daemonAddr)
|
conn, err := connectToDaemon(daemonAddr)
|
||||||
@@ -379,7 +381,13 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st
|
|||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
|
|
||||||
|
var browserOpener func(string) error
|
||||||
|
if !noBrowser {
|
||||||
|
browserOpener = util.OpenBrowser
|
||||||
|
}
|
||||||
|
|
||||||
|
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||||
|
|||||||
@@ -67,8 +67,31 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe
|
|||||||
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// printAuthInstructions prints authentication instructions to stderr
|
||||||
|
func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) {
|
||||||
|
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||||
|
|
||||||
|
if browserWillOpen {
|
||||||
|
_, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.")
|
||||||
|
_, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:")
|
||||||
|
_, _ = fmt.Fprintln(stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete)
|
||||||
|
|
||||||
|
if authResponse.UserCode != "" {
|
||||||
|
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if browserWillOpen {
|
||||||
|
_, _ = fmt.Fprintln(stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||||
|
}
|
||||||
|
|
||||||
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
|
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) {
|
||||||
req := &proto.RequestJWTAuthRequest{}
|
req := &proto.RequestJWTAuthRequest{}
|
||||||
if hint != "" {
|
if hint != "" {
|
||||||
req.Hint = &hint
|
req.Hint = &hint
|
||||||
@@ -84,12 +107,13 @@ func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdo
|
|||||||
}
|
}
|
||||||
|
|
||||||
if stderr != nil {
|
if stderr != nil {
|
||||||
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
printAuthInstructions(stderr, authResponse, openBrowser != nil)
|
||||||
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
|
}
|
||||||
if authResponse.UserCode != "" {
|
|
||||||
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
if openBrowser != nil {
|
||||||
|
if err := openBrowser(authResponse.VerificationURIComplete); err != nil {
|
||||||
|
log.Debugf("open browser: %v", err)
|
||||||
}
|
}
|
||||||
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
||||||
|
|||||||
@@ -41,9 +41,10 @@ type SSHProxy struct {
|
|||||||
stderr io.Writer
|
stderr io.Writer
|
||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
daemonClient proto.DaemonServiceClient
|
daemonClient proto.DaemonServiceClient
|
||||||
|
browserOpener func(string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
|
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
|
||||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -57,6 +58,7 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP
|
|||||||
stderr: stderr,
|
stderr: stderr,
|
||||||
conn: grpcConn,
|
conn: grpcConn,
|
||||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||||
|
browserOpener: browserOpener,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,7 +72,7 @@ func (p *SSHProxy) Close() error {
|
|||||||
func (p *SSHProxy) Connect(ctx context.Context) error {
|
func (p *SSHProxy) Connect(ctx context.Context) error {
|
||||||
hint := profilemanager.GetLoginHint()
|
hint := profilemanager.GetLoginHint()
|
||||||
|
|
||||||
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
|
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(jwtAuthErrorMsg, err)
|
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ func TestSSHProxy_Connect(t *testing.T) {
|
|||||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||||
mockDaemon.setJWTToken(validToken)
|
mockDaemon.setJWTToken(validToken)
|
||||||
|
|
||||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
|
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
clientConn, proxyConn := net.Pipe()
|
clientConn, proxyConn := net.Pipe()
|
||||||
|
|||||||
@@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectUtilLinuxLogin always returns false on JS/WASM
|
||||||
|
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// executeCommandWithPty is not supported on JS/WASM
|
// executeCommandWithPty is not supported on JS/WASM
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
logger.Errorf("PTY command execution not supported on JS/WASM")
|
logger.Errorf("PTY command execution not supported on JS/WASM")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool {
|
|||||||
return supported
|
return supported
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils).
|
||||||
|
// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent.
|
||||||
|
// See https://bugs.debian.org/1078023 for details.
|
||||||
|
func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "login", "--version")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("login --version failed (likely shadow-utils): %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
isUtilLinux := strings.Contains(string(output), "util-linux")
|
||||||
|
log.Debugf("util-linux login detected: %v", isUtilLinux)
|
||||||
|
return isUtilLinux
|
||||||
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching
|
// createSuCommand creates a command using su -l -c for privilege switching
|
||||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||||
suPath, err := exec.LookPath("su")
|
suPath, err := exec.LookPath("su")
|
||||||
@@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("starting interactive shell: %s", execCmd.Path)
|
logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
|
||||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectUtilLinuxLogin always returns false on Windows
|
||||||
|
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
command := session.RawCommand()
|
command := session.RawCommand()
|
||||||
|
|||||||
@@ -139,6 +139,7 @@ type Server struct {
|
|||||||
jwtConfig *JWTConfig
|
jwtConfig *JWTConfig
|
||||||
|
|
||||||
suSupportsPty bool
|
suSupportsPty bool
|
||||||
|
loginIsUtilLinux bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWTConfig struct {
|
type JWTConfig struct {
|
||||||
@@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.suSupportsPty = s.detectSuPtySupport(ctx)
|
s.suSupportsPty = s.detectSuPtySupport(ctx)
|
||||||
|
s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
|
||||||
|
|
||||||
ln, addrDesc, err := s.createListener(ctx, addr)
|
ln, addrDesc, err := s.createListener(ctx, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
|
|||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "linux":
|
case "linux":
|
||||||
// Special handling for Arch Linux without /etc/pam.d/remote
|
p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
|
||||||
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
return p, a, nil
|
||||||
return loginPath, []string{"-f", username, "-p"}, nil
|
|
||||||
}
|
|
||||||
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
|
|
||||||
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
|
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
|
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
|
||||||
default:
|
default:
|
||||||
@@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fileExists checks if a file exists (helper for login command logic)
|
// getLinuxLoginCmd returns the login command for Linux systems.
|
||||||
|
// Handles differences between util-linux and shadow-utils login implementations.
|
||||||
|
func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
|
||||||
|
// Special handling for Arch Linux without /etc/pam.d/remote
|
||||||
|
var loginArgs []string
|
||||||
|
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
||||||
|
loginArgs = []string{"-f", username, "-p"}
|
||||||
|
} else {
|
||||||
|
loginArgs = []string{"-f", username, "-h", remoteIP, "-p"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// util-linux login requires setsid -c to create a new session and set the
|
||||||
|
// controlling terminal. Without this, vhangup() kills the parent process.
|
||||||
|
// See https://bugs.debian.org/1078023 for details.
|
||||||
|
// TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
|
||||||
|
// to avoid external setsid dependency.
|
||||||
|
if !s.loginIsUtilLinux {
|
||||||
|
return loginPath, loginArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
setsidPath, err := exec.LookPath("setsid")
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
|
||||||
|
return loginPath, loginArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
args := append([]string{"-w", "-c", loginPath}, loginArgs...)
|
||||||
|
return setsidPath, args
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileExists checks if a file exists
|
||||||
func (s *Server) fileExists(path string) bool {
|
func (s *Server) fileExists(path string) bool {
|
||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
return err == nil
|
return err == nil
|
||||||
|
|||||||
@@ -72,7 +72,8 @@ func IsSystemAccount(username string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
|
return strings.HasSuffix(username, "$")
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterTestUserCleanup registers a test user for cleanup
|
// RegisterTestUserCleanup registers a test user for cleanup
|
||||||
|
|||||||
115
client/ssh/testutil/user_helpers_test.go
Normal file
115
client/ssh/testutil/user_helpers_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestUserCurrentBehavior validates user.Current() behavior on Windows.
|
||||||
|
// When running as SYSTEM on a domain-joined machine, user.Current() returns:
|
||||||
|
// - Username: Computer account name (e.g., "DOMAIN\MACHINE$")
|
||||||
|
// - SID: SYSTEM SID (S-1-5-18)
|
||||||
|
func TestUserCurrentBehavior(t *testing.T) {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
t.Skip("Windows-specific test")
|
||||||
|
}
|
||||||
|
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err, "Should be able to get current user")
|
||||||
|
|
||||||
|
t.Logf("Current user - Username: %s, SID: %s", currentUser.Username, currentUser.Uid)
|
||||||
|
|
||||||
|
// When running as SYSTEM, validate expected behavior
|
||||||
|
if currentUser.Uid == "S-1-5-18" {
|
||||||
|
t.Run("SYSTEM_account_behavior", func(t *testing.T) {
|
||||||
|
// SID must be S-1-5-18 for SYSTEM
|
||||||
|
require.Equal(t, "S-1-5-18", currentUser.Uid,
|
||||||
|
"SYSTEM account must have SID S-1-5-18")
|
||||||
|
|
||||||
|
// Username can be either "NT AUTHORITY\SYSTEM" (standalone)
|
||||||
|
// or "DOMAIN\MACHINE$" (domain-joined)
|
||||||
|
username := currentUser.Username
|
||||||
|
isNTAuthority := strings.Contains(strings.ToUpper(username), "NT AUTHORITY")
|
||||||
|
isComputerAccount := strings.HasSuffix(username, "$")
|
||||||
|
|
||||||
|
assert.True(t, isNTAuthority || isComputerAccount,
|
||||||
|
"Username should be either 'NT AUTHORITY\\SYSTEM' or computer account (ending with $), got: %s",
|
||||||
|
username)
|
||||||
|
|
||||||
|
if isComputerAccount {
|
||||||
|
t.Logf("SYSTEM as computer account: %s", username)
|
||||||
|
} else if isNTAuthority {
|
||||||
|
t.Logf("SYSTEM as NT AUTHORITY\\SYSTEM")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that IsSystemAccount correctly identifies system accounts
|
||||||
|
t.Run("IsSystemAccount_validation", func(t *testing.T) {
|
||||||
|
// Test with current user if it's a system account
|
||||||
|
if currentUser.Uid == "S-1-5-18" || // SYSTEM
|
||||||
|
currentUser.Uid == "S-1-5-19" || // LOCAL SERVICE
|
||||||
|
currentUser.Uid == "S-1-5-20" { // NETWORK SERVICE
|
||||||
|
|
||||||
|
result := IsSystemAccount(currentUser.Username)
|
||||||
|
assert.True(t, result,
|
||||||
|
"IsSystemAccount should recognize system account: %s (SID: %s)",
|
||||||
|
currentUser.Username, currentUser.Uid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test explicit cases
|
||||||
|
testCases := []struct {
|
||||||
|
username string
|
||||||
|
expected bool
|
||||||
|
reason string
|
||||||
|
}{
|
||||||
|
{"NT AUTHORITY\\SYSTEM", true, "NT AUTHORITY\\SYSTEM"},
|
||||||
|
{"system", true, "system"},
|
||||||
|
{"SYSTEM", true, "SYSTEM (case insensitive)"},
|
||||||
|
{"NT AUTHORITY\\LOCAL SERVICE", true, "LOCAL SERVICE"},
|
||||||
|
{"NT AUTHORITY\\NETWORK SERVICE", true, "NETWORK SERVICE"},
|
||||||
|
{"DOMAIN\\MACHINE$", true, "computer account (ends with $)"},
|
||||||
|
{"WORKGROUP\\WIN2K19-C2$", true, "computer account (ends with $)"},
|
||||||
|
{"Administrator", false, "Administrator is not a system account"},
|
||||||
|
{"alice", false, "regular user"},
|
||||||
|
{"DOMAIN\\alice", false, "domain user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.username, func(t *testing.T) {
|
||||||
|
result := IsSystemAccount(tc.username)
|
||||||
|
assert.Equal(t, tc.expected, result,
|
||||||
|
"IsSystemAccount(%q) should be %v because: %s",
|
||||||
|
tc.username, tc.expected, tc.reason)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestComputerAccountDetection validates computer account detection.
|
||||||
|
func TestComputerAccountDetection(t *testing.T) {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
t.Skip("Windows-specific test")
|
||||||
|
}
|
||||||
|
|
||||||
|
computerAccounts := []string{
|
||||||
|
"MACHINE$",
|
||||||
|
"WIN2K19-C2$",
|
||||||
|
"DOMAIN\\MACHINE$",
|
||||||
|
"WORKGROUP\\SERVER$",
|
||||||
|
"server.domain.com$",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, account := range computerAccounts {
|
||||||
|
t.Run(account, func(t *testing.T) {
|
||||||
|
result := IsSystemAccount(account)
|
||||||
|
assert.True(t, result,
|
||||||
|
"Computer account %q should be recognized as system account", account)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -120,6 +120,26 @@ func (i *Info) SetFlags(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *Info) CopyFlagsFrom(other *Info) {
|
||||||
|
i.SetFlags(
|
||||||
|
other.RosenpassEnabled,
|
||||||
|
other.RosenpassPermissive,
|
||||||
|
&other.ServerSSHAllowed,
|
||||||
|
other.DisableClientRoutes,
|
||||||
|
other.DisableServerRoutes,
|
||||||
|
other.DisableDNS,
|
||||||
|
other.DisableFirewall,
|
||||||
|
other.BlockLANAccess,
|
||||||
|
other.BlockInbound,
|
||||||
|
other.LazyConnectionEnabled,
|
||||||
|
&other.EnableSSHRoot,
|
||||||
|
&other.EnableSSHSFTP,
|
||||||
|
&other.EnableSSHLocalPortForwarding,
|
||||||
|
&other.EnableSSHRemotePortForwarding,
|
||||||
|
&other.DisableSSHAuth,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
func extractUserAgent(ctx context.Context) string {
|
func extractUserAgent(ctx context.Context) string {
|
||||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||||
|
|||||||
@@ -8,6 +8,90 @@ import (
|
|||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestInfo_CopyFlagsFrom(t *testing.T) {
|
||||||
|
origin := &Info{}
|
||||||
|
serverSSHAllowed := true
|
||||||
|
enableSSHRoot := true
|
||||||
|
enableSSHSFTP := false
|
||||||
|
enableSSHLocalPortForwarding := true
|
||||||
|
enableSSHRemotePortForwarding := false
|
||||||
|
disableSSHAuth := true
|
||||||
|
origin.SetFlags(
|
||||||
|
true, // RosenpassEnabled
|
||||||
|
false, // RosenpassPermissive
|
||||||
|
&serverSSHAllowed,
|
||||||
|
true, // DisableClientRoutes
|
||||||
|
false, // DisableServerRoutes
|
||||||
|
true, // DisableDNS
|
||||||
|
false, // DisableFirewall
|
||||||
|
true, // BlockLANAccess
|
||||||
|
false, // BlockInbound
|
||||||
|
true, // LazyConnectionEnabled
|
||||||
|
&enableSSHRoot,
|
||||||
|
&enableSSHSFTP,
|
||||||
|
&enableSSHLocalPortForwarding,
|
||||||
|
&enableSSHRemotePortForwarding,
|
||||||
|
&disableSSHAuth,
|
||||||
|
)
|
||||||
|
|
||||||
|
got := &Info{}
|
||||||
|
got.CopyFlagsFrom(origin)
|
||||||
|
|
||||||
|
if got.RosenpassEnabled != true {
|
||||||
|
t.Fatalf("RosenpassEnabled not copied: got %v", got.RosenpassEnabled)
|
||||||
|
}
|
||||||
|
if got.RosenpassPermissive != false {
|
||||||
|
t.Fatalf("RosenpassPermissive not copied: got %v", got.RosenpassPermissive)
|
||||||
|
}
|
||||||
|
if got.ServerSSHAllowed != true {
|
||||||
|
t.Fatalf("ServerSSHAllowed not copied: got %v", got.ServerSSHAllowed)
|
||||||
|
}
|
||||||
|
if got.DisableClientRoutes != true {
|
||||||
|
t.Fatalf("DisableClientRoutes not copied: got %v", got.DisableClientRoutes)
|
||||||
|
}
|
||||||
|
if got.DisableServerRoutes != false {
|
||||||
|
t.Fatalf("DisableServerRoutes not copied: got %v", got.DisableServerRoutes)
|
||||||
|
}
|
||||||
|
if got.DisableDNS != true {
|
||||||
|
t.Fatalf("DisableDNS not copied: got %v", got.DisableDNS)
|
||||||
|
}
|
||||||
|
if got.DisableFirewall != false {
|
||||||
|
t.Fatalf("DisableFirewall not copied: got %v", got.DisableFirewall)
|
||||||
|
}
|
||||||
|
if got.BlockLANAccess != true {
|
||||||
|
t.Fatalf("BlockLANAccess not copied: got %v", got.BlockLANAccess)
|
||||||
|
}
|
||||||
|
if got.BlockInbound != false {
|
||||||
|
t.Fatalf("BlockInbound not copied: got %v", got.BlockInbound)
|
||||||
|
}
|
||||||
|
if got.LazyConnectionEnabled != true {
|
||||||
|
t.Fatalf("LazyConnectionEnabled not copied: got %v", got.LazyConnectionEnabled)
|
||||||
|
}
|
||||||
|
if got.EnableSSHRoot != true {
|
||||||
|
t.Fatalf("EnableSSHRoot not copied: got %v", got.EnableSSHRoot)
|
||||||
|
}
|
||||||
|
if got.EnableSSHSFTP != false {
|
||||||
|
t.Fatalf("EnableSSHSFTP not copied: got %v", got.EnableSSHSFTP)
|
||||||
|
}
|
||||||
|
if got.EnableSSHLocalPortForwarding != true {
|
||||||
|
t.Fatalf("EnableSSHLocalPortForwarding not copied: got %v", got.EnableSSHLocalPortForwarding)
|
||||||
|
}
|
||||||
|
if got.EnableSSHRemotePortForwarding != false {
|
||||||
|
t.Fatalf("EnableSSHRemotePortForwarding not copied: got %v", got.EnableSSHRemotePortForwarding)
|
||||||
|
}
|
||||||
|
if got.DisableSSHAuth != true {
|
||||||
|
t.Fatalf("DisableSSHAuth not copied: got %v", got.DisableSSHAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure CopyFlagsFrom does not touch unrelated fields
|
||||||
|
origin.Hostname = "host-a"
|
||||||
|
got.Hostname = "host-b"
|
||||||
|
got.CopyFlagsFrom(origin)
|
||||||
|
if got.Hostname != "host-b" {
|
||||||
|
t.Fatalf("CopyFlagsFrom should not overwrite non-flag fields, got Hostname=%q", got.Hostname)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_LocalWTVersion(t *testing.T) {
|
func Test_LocalWTVersion(t *testing.T) {
|
||||||
got := GetInfo(context.TODO())
|
got := GetInfo(context.TODO())
|
||||||
want := "development"
|
want := "development"
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/sleep"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||||
"github.com/netbirdio/netbird/client/ui/event"
|
"github.com/netbirdio/netbird/client/ui/event"
|
||||||
@@ -213,6 +214,7 @@ type serviceClient struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
addr string
|
addr string
|
||||||
conn proto.DaemonServiceClient
|
conn proto.DaemonServiceClient
|
||||||
|
connLock sync.Mutex
|
||||||
|
|
||||||
eventHandler *eventHandler
|
eventHandler *eventHandler
|
||||||
|
|
||||||
@@ -1098,6 +1100,9 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
|
|
||||||
go s.eventManager.Start(s.ctx)
|
go s.eventManager.Start(s.ctx)
|
||||||
go s.eventHandler.listen(s.ctx)
|
go s.eventHandler.listen(s.ctx)
|
||||||
|
|
||||||
|
// Start sleep detection listener
|
||||||
|
go s.startSleepListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
||||||
@@ -1134,6 +1139,8 @@ func (s *serviceClient) onTrayExit() {
|
|||||||
|
|
||||||
// getSrvClient connection to the service.
|
// getSrvClient connection to the service.
|
||||||
func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) {
|
func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) {
|
||||||
|
s.connLock.Lock()
|
||||||
|
defer s.connLock.Unlock()
|
||||||
if s.conn != nil {
|
if s.conn != nil {
|
||||||
return s.conn, nil
|
return s.conn, nil
|
||||||
}
|
}
|
||||||
@@ -1156,6 +1163,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
|||||||
return s.conn, nil
|
return s.conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// startSleepListener initializes the sleep detection service and listens for sleep events
|
||||||
|
func (s *serviceClient) startSleepListener() {
|
||||||
|
sleepService, err := sleep.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("%v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sleepService.Register(s.handleSleepEvents); err != nil {
|
||||||
|
log.Errorf("failed to start sleep detection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("sleep detection service initialized")
|
||||||
|
|
||||||
|
// Cleanup on context cancellation
|
||||||
|
go func() {
|
||||||
|
<-s.ctx.Done()
|
||||||
|
log.Info("stopping sleep event listener")
|
||||||
|
if err := sleepService.Deregister(); err != nil {
|
||||||
|
log.Errorf("failed to deregister sleep detection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSleepEvents sends a sleep notification to the daemon via gRPC
|
||||||
|
func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
|
||||||
|
conn, err := s.getSrvClient(0)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get daemon client for sleep notification: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &proto.OSLifecycleRequest{}
|
||||||
|
|
||||||
|
switch event {
|
||||||
|
case sleep.EventTypeWakeUp:
|
||||||
|
log.Infof("handle wakeup event: %v", event)
|
||||||
|
req.Type = proto.OSLifecycleRequest_WAKEUP
|
||||||
|
case sleep.EventTypeSleep:
|
||||||
|
log.Infof("handle sleep event: %v", event)
|
||||||
|
req.Type = proto.OSLifecycleRequest_SLEEP
|
||||||
|
default:
|
||||||
|
log.Infof("unknown event: %v", event)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.NotifyOSLifecycle(s.ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("successfully notified daemon about os lifecycle")
|
||||||
|
}
|
||||||
|
|
||||||
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
||||||
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||||
if s.mSettings != nil {
|
if s.mSettings != nil {
|
||||||
|
|||||||
@@ -28,7 +28,8 @@ func IsAnotherProcessRunning() (int32, bool, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
|
runningProcessName := strings.ToLower(filepath.Base(runningProcessPath))
|
||||||
|
if runningProcessName == processName && isProcessOwnedByCurrentUser(p) {
|
||||||
return p.Pid, true, nil
|
return p.Pid, true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,6 +45,10 @@ type CustomZone struct {
|
|||||||
Domain string
|
Domain string
|
||||||
// Records custom zone records
|
// Records custom zone records
|
||||||
Records []SimpleRecord
|
Records []SimpleRecord
|
||||||
|
// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
|
||||||
|
SearchDomainDisabled bool
|
||||||
|
// SkipPTRProcess indicates whether a client should process PTR records from custom zones
|
||||||
|
SkipPTRProcess bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
|
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
|
||||||
|
|||||||
24
go.mod
24
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module github.com/netbirdio/netbird
|
module github.com/netbirdio/netbird
|
||||||
|
|
||||||
go 1.23.1
|
go 1.24.10
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cunicu.li/go-rosenpass v0.4.0
|
cunicu.li/go-rosenpass v0.4.0
|
||||||
@@ -17,8 +17,8 @@ require (
|
|||||||
github.com/spf13/cobra v1.7.0
|
github.com/spf13/cobra v1.7.0
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
golang.org/x/crypto v0.41.0
|
golang.org/x/crypto v0.45.0
|
||||||
golang.org/x/sys v0.35.0
|
golang.org/x/sys v0.38.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
@@ -64,7 +64,7 @@ require (
|
|||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
@@ -105,12 +105,12 @@ require (
|
|||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
goauthentik.io/api/v3 v3.2023051.3
|
goauthentik.io/api/v3 v3.2023051.3
|
||||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
|
||||||
golang.org/x/mod v0.26.0
|
golang.org/x/mod v0.30.0
|
||||||
golang.org/x/net v0.42.0
|
golang.org/x/net v0.47.0
|
||||||
golang.org/x/oauth2 v0.30.0
|
golang.org/x/oauth2 v0.30.0
|
||||||
golang.org/x/sync v0.16.0
|
golang.org/x/sync v0.18.0
|
||||||
golang.org/x/term v0.34.0
|
golang.org/x/term v0.37.0
|
||||||
golang.org/x/time v0.12.0
|
golang.org/x/time v0.12.0
|
||||||
google.golang.org/api v0.177.0
|
google.golang.org/api v0.177.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
@@ -251,9 +251,9 @@ require (
|
|||||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/image v0.24.0 // indirect
|
golang.org/x/image v0.33.0 // indirect
|
||||||
golang.org/x/text v0.28.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
golang.org/x/tools v0.35.0 // indirect
|
golang.org/x/tools v0.39.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
|
|||||||
44
go.sum
44
go.sum
@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
@@ -600,19 +600,19 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
|
|||||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||||
golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ=
|
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
|
||||||
golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8=
|
golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc=
|
||||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg=
|
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab h1:Iqyc+2zr7aGyLuEadIm0KRJP0Wwt+fhlXLa51Fxf1+Q=
|
||||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc=
|
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab/go.mod h1:Eq3Nh/5pFSWug2ohiudJ1iyU59SO78QFuh4qTTN++I0=
|
||||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
@@ -622,8 +622,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
|||||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
@@ -647,8 +647,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
|||||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||||
@@ -665,8 +665,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
|||||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@@ -703,8 +703,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|||||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
@@ -717,8 +717,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
|||||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||||
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
|
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||||
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
|
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
@@ -730,8 +730,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
|||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
@@ -749,8 +749,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
|||||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
@@ -42,6 +43,7 @@ type Controller struct {
|
|||||||
accountManagerMetrics *telemetry.AccountManagerMetrics
|
accountManagerMetrics *telemetry.AccountManagerMetrics
|
||||||
peersUpdateManager network_map.PeersUpdateManager
|
peersUpdateManager network_map.PeersUpdateManager
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
|
EphemeralPeersManager ephemeral.Manager
|
||||||
|
|
||||||
accountUpdateLocks sync.Map
|
accountUpdateLocks sync.Map
|
||||||
sendAccountUpdateLocks sync.Map
|
sendAccountUpdateLocks sync.Map
|
||||||
@@ -70,7 +72,7 @@ type bufferUpdate struct {
|
|||||||
|
|
||||||
var _ network_map.Controller = (*Controller)(nil)
|
var _ network_map.Controller = (*Controller)(nil)
|
||||||
|
|
||||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller {
|
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
|
||||||
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
|
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||||
@@ -100,6 +102,7 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
config: config,
|
config: config,
|
||||||
|
|
||||||
proxyController: proxyController,
|
proxyController: proxyController,
|
||||||
|
EphemeralPeersManager: ephemeralPeersManager,
|
||||||
|
|
||||||
holder: types.NewHolder(),
|
holder: types.NewHolder(),
|
||||||
expNewNetworkMap: newNetworkMapBuilder,
|
expNewNetworkMap: newNetworkMapBuilder,
|
||||||
@@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||||
|
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
|
||||||
|
|
||||||
|
return c.peersUpdateManager.CreateChannel(ctx, peerID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) {
|
||||||
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) CountStreams() int {
|
||||||
|
return c.peersUpdateManager.CountStreams()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||||
var (
|
var (
|
||||||
@@ -366,55 +394,26 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
|
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
network, err := c.repo.GetAccountNetwork(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peers, err := c.repo.GetAccountPeers(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
|
||||||
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{
|
|
||||||
RemotePeers: []*proto.RemotePeerConfig{},
|
|
||||||
RemotePeersIsEmpty: true,
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
|
||||||
Serial: network.CurrentSerial(),
|
|
||||||
RemotePeers: []*proto.RemotePeerConfig{},
|
|
||||||
RemotePeersIsEmpty: true,
|
|
||||||
FirewallRules: []*proto.FirewallRule{},
|
|
||||||
FirewallRulesIsEmpty: true,
|
|
||||||
DNSConfig: &proto.DNSConfig{
|
|
||||||
ForwarderPort: dnsFwdPort,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerId)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
|
||||||
if isRequiresApproval {
|
|
||||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isRequiresApproval {
|
||||||
emptyMap := &types.NetworkMap{
|
emptyMap := &types.NetworkMap{
|
||||||
Network: network.Copy(),
|
Network: network.Copy(),
|
||||||
}
|
}
|
||||||
return peer, emptyMap, nil, 0, nil
|
return peer, emptyMap, nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
if clientSerial > 0 && clientSerial == network.CurrentSerial() {
|
||||||
account *types.Account
|
log.WithContext(ctx).Debugf("client serial %d matches current serial, skipping network map calculation", clientSerial)
|
||||||
err error
|
return peer, nil, nil, 0, nil
|
||||||
)
|
}
|
||||||
|
|
||||||
|
var account *types.Account
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
account = c.getAccountFromHolderOrInit(accountID)
|
account = c.getAccountFromHolderOrInit(accountID)
|
||||||
} else {
|
} else {
|
||||||
@@ -698,12 +697,26 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
|
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
c.UpdatePeerInNetworkMapCache(accountId, peer)
|
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||||
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
|
for _, peer := range peers {
|
||||||
|
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -715,19 +728,53 @@ func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID s
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
|
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
if c.experimentalNetworkMap(accountID) {
|
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
|
||||||
|
peers, err := c.repo.GetAccountPeers(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
RemotePeers: []*proto.RemotePeerConfig{},
|
||||||
|
RemotePeersIsEmpty: true,
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: network.CurrentSerial(),
|
||||||
|
RemotePeers: []*proto.RemotePeerConfig{},
|
||||||
|
RemotePeersIsEmpty: true,
|
||||||
|
FirewallRules: []*proto.FirewallRule{},
|
||||||
|
FirewallRulesIsEmpty: true,
|
||||||
|
DNSConfig: &proto.DNSConfig{
|
||||||
|
ForwarderPort: dnsFwdPort,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
@@ -778,10 +825,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
return networkMap, nil
|
return networkMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
|
||||||
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) IsConnected(peerID string) bool {
|
|
||||||
return c.peersUpdateManager.HasChannel(peerID)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,16 +1,9 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -114,131 +107,3 @@ func TestComputeForwarderPort(t *testing.T) {
|
|||||||
t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result)
|
t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBufferUpdateAccountPeers(t *testing.T) {
|
|
||||||
const (
|
|
||||||
peersCount = 1000
|
|
||||||
updateAccountInterval = 50 * time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
|
|
||||||
uapLastRun, dpLastRun atomic.Int64
|
|
||||||
|
|
||||||
totalNewRuns, totalOldRuns int
|
|
||||||
)
|
|
||||||
|
|
||||||
uap := func(ctx context.Context, accountID string) {
|
|
||||||
updatePeersDeleted.Store(deletedPeers.Load())
|
|
||||||
updatePeersRuns.Add(1)
|
|
||||||
uapLastRun.Store(time.Now().UnixMilli())
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("new approach", func(t *testing.T) {
|
|
||||||
updatePeersRuns.Store(0)
|
|
||||||
updatePeersDeleted.Store(0)
|
|
||||||
deletedPeers.Store(0)
|
|
||||||
|
|
||||||
var mustore sync.Map
|
|
||||||
bufupd := func(ctx context.Context, accountID string) {
|
|
||||||
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
|
|
||||||
b := mu.(*bufferUpdate)
|
|
||||||
|
|
||||||
if !b.mu.TryLock() {
|
|
||||||
b.update.Store(true)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.next != nil {
|
|
||||||
b.next.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer b.mu.Unlock()
|
|
||||||
uap(ctx, accountID)
|
|
||||||
if !b.update.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.update.Store(false)
|
|
||||||
b.next = time.AfterFunc(updateAccountInterval, func() {
|
|
||||||
uap(ctx, accountID)
|
|
||||||
})
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
|
||||||
deletedPeers.Add(1)
|
|
||||||
dpLastRun.Store(time.Now().UnixMilli())
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
bufupd(ctx, accountID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
am := mock_server.MockAccountManager{
|
|
||||||
UpdateAccountPeersFunc: uap,
|
|
||||||
BufferUpdateAccountPeersFunc: bufupd,
|
|
||||||
DeletePeerFunc: dp,
|
|
||||||
}
|
|
||||||
empty := ""
|
|
||||||
for range peersCount {
|
|
||||||
//nolint
|
|
||||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
|
||||||
}
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
|
||||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
|
||||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
|
||||||
|
|
||||||
totalNewRuns = int(updatePeersRuns.Load())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("old approach", func(t *testing.T) {
|
|
||||||
updatePeersRuns.Store(0)
|
|
||||||
updatePeersDeleted.Store(0)
|
|
||||||
deletedPeers.Store(0)
|
|
||||||
|
|
||||||
var mustore sync.Map
|
|
||||||
bufupd := func(ctx context.Context, accountID string) {
|
|
||||||
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
|
|
||||||
b := mu.(*sync.Mutex)
|
|
||||||
|
|
||||||
if !b.TryLock() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
time.Sleep(updateAccountInterval)
|
|
||||||
b.Unlock()
|
|
||||||
uap(ctx, accountID)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
|
||||||
deletedPeers.Add(1)
|
|
||||||
dpLastRun.Store(time.Now().UnixMilli())
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
bufupd(ctx, accountID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
am := mock_server.MockAccountManager{
|
|
||||||
UpdateAccountPeersFunc: uap,
|
|
||||||
BufferUpdateAccountPeersFunc: bufupd,
|
|
||||||
DeletePeerFunc: dp,
|
|
||||||
}
|
|
||||||
empty := ""
|
|
||||||
for range peersCount {
|
|
||||||
//nolint
|
|
||||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
|
||||||
}
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
|
||||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
|
||||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
|
||||||
|
|
||||||
totalOldRuns = int(updatePeersRuns.Load())
|
|
||||||
})
|
|
||||||
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
|
||||||
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ type Repository interface {
|
|||||||
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
|
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
|
||||||
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
|
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
|
||||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||||
|
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||||
|
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type repository struct {
|
type repository struct {
|
||||||
@@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*
|
|||||||
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
|
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
|
||||||
return r.store.GetAccountByPeerID(ctx, peerID)
|
return r.store.GetAccountByPeerID(ctx, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
|
||||||
|
return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
|
||||||
|
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,16 +24,16 @@ type Controller interface {
|
|||||||
UpdateAccountPeers(ctx context.Context, accountID string) error
|
UpdateAccountPeers(ctx context.Context, accountID string) error
|
||||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||||
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
||||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
GetDNSDomain(settings *types.Settings) string
|
GetDNSDomain(settings *types.Settings) string
|
||||||
StartWarmup(context.Context)
|
StartWarmup(context.Context)
|
||||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||||
|
CountStreams() int
|
||||||
|
|
||||||
DeletePeer(ctx context.Context, accountId string, peerId string) error
|
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
|
||||||
|
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
|
||||||
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
|
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
|
||||||
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
|
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
|
||||||
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
|
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
|
||||||
DisconnectPeers(ctx context.Context, peerIDs []string)
|
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
|
||||||
IsConnected(peerID string) bool
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,30 +57,30 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID an
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeer mocks base method.
|
// CountStreams mocks base method.
|
||||||
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
|
func (m *MockController) CountStreams() int {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
|
ret := m.ctrl.Call(m, "CountStreams")
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(int)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeer indicates an expected call of DeletePeer.
|
// CountStreams indicates an expected call of CountStreams.
|
||||||
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisconnectPeers mocks base method.
|
// DisconnectPeers mocks base method.
|
||||||
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
|
m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisconnectPeers indicates an expected call of DisconnectPeers.
|
// DisconnectPeers indicates an expected call of DisconnectPeers.
|
||||||
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDNSDomain mocks base method.
|
// GetDNSDomain mocks base method.
|
||||||
@@ -113,9 +113,9 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetValidatedPeerWithMap mocks base method.
|
// GetValidatedPeerWithMap mocks base method.
|
||||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer, clientSerial uint64) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
|
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p, clientSerial)
|
||||||
ret0, _ := ret[0].(*peer.Peer)
|
ret0, _ := ret[0].(*peer.Peer)
|
||||||
ret1, _ := ret[1].(*types.NetworkMap)
|
ret1, _ := ret[1].(*types.NetworkMap)
|
||||||
ret2, _ := ret[2].([]*posture.Checks)
|
ret2, _ := ret[2].([]*posture.Checks)
|
||||||
@@ -125,63 +125,78 @@ func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequires
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
||||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p, clientSerial any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p, clientSerial)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsConnected mocks base method.
|
// OnPeerConnected mocks base method.
|
||||||
func (m *MockController) IsConnected(peerID string) bool {
|
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "IsConnected", peerID)
|
ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(chan *UpdateMessage)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsConnected indicates an expected call of IsConnected.
|
// OnPeerConnected indicates an expected call of OnPeerConnected.
|
||||||
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPeerAdded mocks base method.
|
// OnPeerDisconnected mocks base method.
|
||||||
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
|
func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
|
m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPeerDisconnected indicates an expected call of OnPeerDisconnected.
|
||||||
|
func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPeersAdded mocks base method.
|
||||||
|
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPeerAdded indicates an expected call of OnPeerAdded.
|
// OnPeersAdded indicates an expected call of OnPeersAdded.
|
||||||
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPeerDeleted mocks base method.
|
// OnPeersDeleted mocks base method.
|
||||||
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
|
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
|
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
|
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
|
||||||
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPeerUpdated mocks base method.
|
// OnPeersUpdated mocks base method.
|
||||||
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
|
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
|
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
|
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
|
||||||
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartWarmup mocks base method.
|
// StartWarmup mocks base method.
|
||||||
|
|||||||
@@ -2,10 +2,15 @@ package ephemeral
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EphemeralLifeTime = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
LoadInitialPeers(ctx context.Context)
|
LoadInitialPeers(ctx context.Context)
|
||||||
Stop()
|
Stop()
|
||||||
@@ -7,14 +7,15 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ephemeralLifeTime = 10 * time.Minute
|
|
||||||
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
|
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
|
||||||
cleanupWindow = 1 * time.Minute
|
cleanupWindow = 1 * time.Minute
|
||||||
)
|
)
|
||||||
@@ -33,11 +34,11 @@ type ephemeralPeer struct {
|
|||||||
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
|
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
|
||||||
// in worst case we will get invalid error message in this manager.
|
// in worst case we will get invalid error message in this manager.
|
||||||
|
|
||||||
// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted
|
// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
|
||||||
// automatically. Inactivity means the peer disconnected from the Management server.
|
// automatically. Inactivity means the peer disconnected from the Management server.
|
||||||
type EphemeralManager struct {
|
type EphemeralManager struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
accountManager nbAccount.Manager
|
peersManager peers.Manager
|
||||||
|
|
||||||
headPeer *ephemeralPeer
|
headPeer *ephemeralPeer
|
||||||
tailPeer *ephemeralPeer
|
tailPeer *ephemeralPeer
|
||||||
@@ -49,12 +50,12 @@ type EphemeralManager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewEphemeralManager instantiate new EphemeralManager
|
// NewEphemeralManager instantiate new EphemeralManager
|
||||||
func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager {
|
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
|
||||||
return &EphemeralManager{
|
return &EphemeralManager{
|
||||||
store: store,
|
store: store,
|
||||||
accountManager: accountManager,
|
peersManager: peersManager,
|
||||||
|
|
||||||
lifeTime: ephemeralLifeTime,
|
lifeTime: ephemeral.EphemeralLifeTime,
|
||||||
cleanupWindow: cleanupWindow,
|
cleanupWindow: cleanupWindow,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -106,7 +107,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
|
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
|
||||||
// is inactive it will be deleted after the ephemeralLifeTime period.
|
// is inactive it will be deleted after the EphemeralLifeTime period.
|
||||||
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
|
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||||
if !peer.Ephemeral {
|
if !peer.Ephemeral {
|
||||||
return
|
return
|
||||||
@@ -180,20 +181,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
|||||||
|
|
||||||
e.peersLock.Unlock()
|
e.peersLock.Unlock()
|
||||||
|
|
||||||
bufferAccountCall := make(map[string]struct{})
|
peerIDsPerAccount := make(map[string][]string)
|
||||||
|
|
||||||
for id, p := range deletePeers {
|
for id, p := range deletePeers {
|
||||||
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
|
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
||||||
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
|
}
|
||||||
|
|
||||||
|
for accountID, peerIDs := range peerIDsPerAccount {
|
||||||
|
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
|
||||||
|
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||||
} else {
|
|
||||||
bufferAccountCall[p.accountID] = struct{}{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for accountID := range bufferAccountCall {
|
|
||||||
e.accountManager.BufferUpdateAccountPeers(ctx, accountID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
||||||
@@ -7,10 +7,13 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -91,17 +94,27 @@ func TestNewManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
store := &MockStore{}
|
store := &MockStore{}
|
||||||
am := MockAccountManager{
|
ctrl := gomock.NewController(t)
|
||||||
store: store,
|
peersManager := peers.NewMockManager(ctrl)
|
||||||
}
|
|
||||||
|
|
||||||
numberOfPeers := 5
|
numberOfPeers := 5
|
||||||
numberOfEphemeralPeers := 3
|
numberOfEphemeralPeers := 3
|
||||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||||
|
|
||||||
mgr := NewEphemeralManager(store, &am)
|
// Expect DeletePeers to be called for ephemeral peers
|
||||||
|
peersManager.EXPECT().
|
||||||
|
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||||
|
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
delete(store.account.Peers, peerID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
mgr := NewEphemeralManager(store, peersManager)
|
||||||
mgr.loadEphemeralPeers(context.Background())
|
mgr.loadEphemeralPeers(context.Background())
|
||||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||||
mgr.cleanup(context.Background())
|
mgr.cleanup(context.Background())
|
||||||
|
|
||||||
if len(store.account.Peers) != numberOfPeers {
|
if len(store.account.Peers) != numberOfPeers {
|
||||||
@@ -119,19 +132,29 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
store := &MockStore{}
|
store := &MockStore{}
|
||||||
am := MockAccountManager{
|
ctrl := gomock.NewController(t)
|
||||||
store: store,
|
peersManager := peers.NewMockManager(ctrl)
|
||||||
}
|
|
||||||
|
|
||||||
numberOfPeers := 5
|
numberOfPeers := 5
|
||||||
numberOfEphemeralPeers := 3
|
numberOfEphemeralPeers := 3
|
||||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||||
|
|
||||||
mgr := NewEphemeralManager(store, &am)
|
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
|
||||||
|
peersManager.EXPECT().
|
||||||
|
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||||
|
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
delete(store.account.Peers, peerID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
mgr := NewEphemeralManager(store, peersManager)
|
||||||
mgr.loadEphemeralPeers(context.Background())
|
mgr.loadEphemeralPeers(context.Background())
|
||||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||||
|
|
||||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||||
mgr.cleanup(context.Background())
|
mgr.cleanup(context.Background())
|
||||||
|
|
||||||
expected := numberOfPeers + 1
|
expected := numberOfPeers + 1
|
||||||
@@ -150,15 +173,25 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
store := &MockStore{}
|
store := &MockStore{}
|
||||||
am := MockAccountManager{
|
ctrl := gomock.NewController(t)
|
||||||
store: store,
|
peersManager := peers.NewMockManager(ctrl)
|
||||||
}
|
|
||||||
|
|
||||||
numberOfPeers := 5
|
numberOfPeers := 5
|
||||||
numberOfEphemeralPeers := 3
|
numberOfEphemeralPeers := 3
|
||||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||||
|
|
||||||
mgr := NewEphemeralManager(store, &am)
|
// Expect DeletePeers to be called for the one disconnected peer
|
||||||
|
peersManager.EXPECT().
|
||||||
|
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||||
|
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
delete(store.account.Peers, peerID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
mgr := NewEphemeralManager(store, peersManager)
|
||||||
mgr.loadEphemeralPeers(context.Background())
|
mgr.loadEphemeralPeers(context.Background())
|
||||||
for _, v := range store.account.Peers {
|
for _, v := range store.account.Peers {
|
||||||
mgr.OnPeerConnected(context.Background(), v)
|
mgr.OnPeerConnected(context.Background(), v)
|
||||||
@@ -166,7 +199,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||||
|
|
||||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||||
mgr.cleanup(context.Background())
|
mgr.cleanup(context.Background())
|
||||||
|
|
||||||
expected := numberOfPeers + numberOfEphemeralPeers - 1
|
expected := numberOfPeers + numberOfEphemeralPeers - 1
|
||||||
@@ -181,25 +214,63 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
|
|||||||
testLifeTime = 1 * time.Second
|
testLifeTime = 1 * time.Second
|
||||||
testCleanupWindow = 100 * time.Millisecond
|
testCleanupWindow = 100 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
timeNow = time.Now
|
||||||
|
})
|
||||||
|
startTime := time.Now()
|
||||||
|
timeNow = func() time.Time {
|
||||||
|
return startTime
|
||||||
|
}
|
||||||
|
|
||||||
mockStore := &MockStore{}
|
mockStore := &MockStore{}
|
||||||
|
account := newAccountWithId(context.Background(), "account", "", "", false)
|
||||||
|
mockStore.account = account
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(ephemeralPeers)
|
||||||
mockAM := &MockAccountManager{
|
mockAM := &MockAccountManager{
|
||||||
store: mockStore,
|
store: mockStore,
|
||||||
|
wg: wg,
|
||||||
}
|
}
|
||||||
mockAM.wg = &sync.WaitGroup{}
|
|
||||||
mockAM.wg.Add(ephemeralPeers)
|
ctrl := gomock.NewController(t)
|
||||||
mgr := NewEphemeralManager(mockStore, mockAM)
|
peersManager := peers.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
// Set up expectation that DeletePeers will be called once with all peer IDs
|
||||||
|
peersManager.EXPECT().
|
||||||
|
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
|
||||||
|
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
|
// Simulate the actual deletion behavior
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mockAM.BufferUpdateAccountPeers(ctx, accountID)
|
||||||
|
return nil
|
||||||
|
}).
|
||||||
|
Times(1)
|
||||||
|
|
||||||
|
mgr := NewEphemeralManager(mockStore, peersManager)
|
||||||
mgr.lifeTime = testLifeTime
|
mgr.lifeTime = testLifeTime
|
||||||
mgr.cleanupWindow = testCleanupWindow
|
mgr.cleanupWindow = testCleanupWindow
|
||||||
|
|
||||||
account := newAccountWithId(context.Background(), "account", "", "", false)
|
// Add peers and disconnect them at slightly different times (within cleanup window)
|
||||||
mockStore.account = account
|
|
||||||
for i := range ephemeralPeers {
|
for i := range ephemeralPeers {
|
||||||
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
|
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
|
||||||
mockStore.account.Peers[p.ID] = p
|
mockStore.account.Peers[p.ID] = p
|
||||||
time.Sleep(testCleanupWindow / ephemeralPeers)
|
|
||||||
mgr.OnPeerDisconnected(context.Background(), p)
|
mgr.OnPeerDisconnected(context.Background(), p)
|
||||||
|
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
|
||||||
}
|
}
|
||||||
mockAM.wg.Wait()
|
|
||||||
|
// Advance time past the lifetime to trigger cleanup
|
||||||
|
startTime = startTime.Add(testLifeTime + testCleanupWindow)
|
||||||
|
|
||||||
|
// Wait for all deletions to complete
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
|
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
|
||||||
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
|
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
|
||||||
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
|
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
|
||||||
162
management/internals/modules/peers/manager.go
Normal file
162
management/internals/modules/peers/manager.go
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
package peers
|
||||||
|
|
||||||
|
//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
|
||||||
|
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
|
||||||
|
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
|
||||||
|
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
|
||||||
|
DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error
|
||||||
|
SetNetworkMapController(networkMapController network_map.Controller)
|
||||||
|
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
||||||
|
SetAccountManager(accountManager account.Manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
type managerImpl struct {
|
||||||
|
store store.Store
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||||
|
accountManager account.Manager
|
||||||
|
|
||||||
|
networkMapController network_map.Controller
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
|
||||||
|
return &managerImpl{
|
||||||
|
store: store,
|
||||||
|
permissionsManager: permissionsManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) {
|
||||||
|
m.networkMapController = networkMapController
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
|
||||||
|
m.integratedPeerValidator = integratedPeerValidator
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
|
||||||
|
m.accountManager = accountManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
|
||||||
|
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||||
|
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
||||||
|
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||||
|
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
|
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dnsDomain := m.networkMapController.GetDNSDomain(settings)
|
||||||
|
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
var eventsToStore []func()
|
||||||
|
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove peer %s from groups", peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, rule := range peerPolicyRules {
|
||||||
|
policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, event := range eventsToStore {
|
||||||
|
event()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -9,6 +9,9 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
|
account "github.com/netbirdio/netbird/management/server/account"
|
||||||
|
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,6 +38,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
|||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeletePeers mocks base method.
|
||||||
|
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeers indicates an expected call of DeletePeers.
|
||||||
|
func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected)
|
||||||
|
}
|
||||||
|
|
||||||
// GetAllPeers mocks base method.
|
// GetAllPeers mocks base method.
|
||||||
func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -94,3 +111,39 @@ func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccountManager mocks base method.
|
||||||
|
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetAccountManager", accountManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAccountManager indicates an expected call of SetAccountManager.
|
||||||
|
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIntegratedPeerValidator mocks base method.
|
||||||
|
func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator.
|
||||||
|
func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetworkMapController mocks base method.
|
||||||
|
func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetNetworkMapController", networkMapController)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetworkMapController indicates an expected call of SetNetworkMapController.
|
||||||
|
func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||||
|
}
|
||||||
@@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics {
|
|||||||
|
|
||||||
func (s *BaseServer) Store() store.Store {
|
func (s *BaseServer) Store() store.Store {
|
||||||
return Create(s, func() store.Store {
|
return Create(s, func() store.Store {
|
||||||
store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
|
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create store: %v", err)
|
log.Fatalf("failed to create store: %v", err)
|
||||||
}
|
}
|
||||||
@@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
log.Fatalf("failed to initialize integration metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
|
eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to initialize event store: %v", err)
|
log.Fatalf("failed to initialize event store: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.config.DataStoreEncryptionKey != key {
|
if s.Config.DataStoreEncryptionKey != key {
|
||||||
log.WithContext(context.Background()).Infof("update config with activity store key")
|
log.WithContext(context.Background()).Infof("update Config with activity store key")
|
||||||
s.config.DataStoreEncryptionKey = key
|
s.Config.DataStoreEncryptionKey = key
|
||||||
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
|
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to update config with activity store: %v", err)
|
log.Fatalf("failed to update Config with activity store: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler {
|
|||||||
|
|
||||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||||
return Create(s, func() *grpc.Server {
|
return Create(s, func() *grpc.Server {
|
||||||
trustedPeers := s.config.ReverseProxy.TrustedPeers
|
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||||
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
|
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
|
||||||
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||||
trustedPeers = defaultTrustedPeers
|
trustedPeers = defaultTrustedPeers
|
||||||
}
|
}
|
||||||
trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
|
trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies
|
||||||
trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
|
trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount
|
||||||
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
|
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
|
||||||
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||||
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
|
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
|
||||||
@@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.config.HttpConfig.LetsEncryptDomain != "" {
|
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||||
certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
|
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create certificate manager: %v", err)
|
log.Fatalf("failed to create certificate manager: %v", err)
|
||||||
}
|
}
|
||||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||||
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
|
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
|
||||||
tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
|
tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("cannot load TLS credentials: %v", err)
|
log.Fatalf("cannot load TLS credentials: %v", err)
|
||||||
}
|
}
|
||||||
@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||||
srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
|
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create management server: %v", err)
|
log.Fatalf("failed to create management server: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,17 +9,17 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||||
return Create(s, func() *update_channel.PeersUpdateManager {
|
return Create(s, func() network_map.PeersUpdateManager {
|
||||||
return update_channel.NewPeersUpdateManager(s.Metrics())
|
return update_channel.NewPeersUpdateManager(s.Metrics())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -44,33 +44,37 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager {
|
func (s *BaseServer) SecretsManager() grpc.SecretsManager {
|
||||||
return Create(s, func() *grpc.TimeBasedAuthSecretsManager {
|
return Create(s, func() grpc.SecretsManager {
|
||||||
return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
|
secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create secrets manager: %v", err)
|
||||||
|
}
|
||||||
|
return secretsManager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) AuthManager() auth.Manager {
|
func (s *BaseServer) AuthManager() auth.Manager {
|
||||||
return Create(s, func() auth.Manager {
|
return Create(s, func() auth.Manager {
|
||||||
return auth.NewManager(s.Store(),
|
return auth.NewManager(s.Store(),
|
||||||
s.config.HttpConfig.AuthIssuer,
|
s.Config.HttpConfig.AuthIssuer,
|
||||||
s.config.HttpConfig.AuthAudience,
|
s.Config.HttpConfig.AuthAudience,
|
||||||
s.config.HttpConfig.AuthKeysLocation,
|
s.Config.HttpConfig.AuthKeysLocation,
|
||||||
s.config.HttpConfig.AuthUserIDClaim,
|
s.Config.HttpConfig.AuthUserIDClaim,
|
||||||
s.config.GetAuthAudiences(),
|
s.Config.GetAuthAudiences(),
|
||||||
s.config.HttpConfig.IdpSignKeyRefreshEnabled)
|
s.Config.HttpConfig.IdpSignKeyRefreshEnabled)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||||
return Create(s, func() ephemeral.Manager {
|
return Create(s, func() ephemeral.Manager {
|
||||||
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
|
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) NetworkMapController() network_map.Controller {
|
func (s *BaseServer) NetworkMapController() network_map.Controller {
|
||||||
return Create(s, func() *nmapcontroller.Controller {
|
return Create(s, func() network_map.Controller {
|
||||||
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config)
|
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,3 +83,7 @@ func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
|||||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) DNSDomain() string {
|
||||||
|
return s.dnsDomain
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"os"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
@@ -14,20 +16,29 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/networks"
|
"github.com/netbirdio/netbird/management/server/networks"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
geolocationDisabledKey = "NB_DISABLE_GEOLOCATION"
|
||||||
|
)
|
||||||
|
|
||||||
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
||||||
|
if os.Getenv(geolocationDisabledKey) == "true" {
|
||||||
|
log.Info("geolocation service is disabled, skipping initialization")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return Create(s, func() geolocation.Geolocation {
|
return Create(s, func() geolocation.Geolocation {
|
||||||
geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
|
geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("could not initialize geolocation service: %v", err)
|
log.Fatalf("could not initialize geolocation service: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
|
log.Infof("geolocation service has been initialized from %s", s.Config.Datadir)
|
||||||
|
|
||||||
return geo
|
return geo
|
||||||
})
|
})
|
||||||
@@ -60,20 +71,22 @@ func (s *BaseServer) SettingsManager() settings.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) PeersManager() peers.Manager {
|
func (s *BaseServer) PeersManager() peers.Manager {
|
||||||
return Create(s, func() peers.Manager {
|
return Create(s, func() peers.Manager {
|
||||||
return peers.NewManager(s.Store(), s.PermissionsManager())
|
manager := peers.NewManager(s.Store(), s.PermissionsManager())
|
||||||
|
s.AfterInit(func(s *BaseServer) {
|
||||||
|
manager.SetNetworkMapController(s.NetworkMapController())
|
||||||
|
manager.SetIntegratedPeerValidator(s.IntegratedValidator())
|
||||||
|
manager.SetAccountManager(s.AccountManager())
|
||||||
|
})
|
||||||
|
return manager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) AccountManager() account.Manager {
|
func (s *BaseServer) AccountManager() account.Manager {
|
||||||
return Create(s, func() account.Manager {
|
return Create(s, func() account.Manager {
|
||||||
accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
|
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create account manager: %v", err)
|
log.Fatalf("failed to create account manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.AfterInit(func(s *BaseServer) {
|
|
||||||
accountManager.SetEphemeralManager(s.EphemeralManager())
|
|
||||||
})
|
|
||||||
return accountManager
|
return accountManager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -82,8 +95,8 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
|||||||
return Create(s, func() idp.Manager {
|
return Create(s, func() idp.Manager {
|
||||||
var idpManager idp.Manager
|
var idpManager idp.Manager
|
||||||
var err error
|
var err error
|
||||||
if s.config.IdpManagerConfig != nil {
|
if s.Config.IdpManagerConfig != nil {
|
||||||
idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
|
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create IDP manager: %v", err)
|
log.Fatalf("failed to create IDP manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,10 +41,10 @@ type Server interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Server holds the HTTP BaseServer instance.
|
// Server holds the HTTP BaseServer instance.
|
||||||
// Add any additional fields you need, such as database connections, config, etc.
|
// Add any additional fields you need, such as database connections, Config, etc.
|
||||||
type BaseServer struct {
|
type BaseServer struct {
|
||||||
// config holds the server configuration
|
// Config holds the server configuration
|
||||||
config *nbconfig.Config
|
Config *nbconfig.Config
|
||||||
// container of dependencies, each dependency is identified by a unique string.
|
// container of dependencies, each dependency is identified by a unique string.
|
||||||
container map[string]any
|
container map[string]any
|
||||||
// AfterInit is a function that will be called after the server is initialized
|
// AfterInit is a function that will be called after the server is initialized
|
||||||
@@ -70,7 +70,7 @@ type BaseServer struct {
|
|||||||
// NewServer initializes and configures a new Server instance
|
// NewServer initializes and configures a new Server instance
|
||||||
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
|
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
|
||||||
return &BaseServer{
|
return &BaseServer{
|
||||||
config: config,
|
Config: config,
|
||||||
container: make(map[string]any),
|
container: make(map[string]any),
|
||||||
dnsDomain: dnsDomain,
|
dnsDomain: dnsDomain,
|
||||||
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
||||||
@@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
|
|
||||||
var tlsConfig *tls.Config
|
var tlsConfig *tls.Config
|
||||||
tlsEnabled := false
|
tlsEnabled := false
|
||||||
if s.config.HttpConfig.LetsEncryptDomain != "" {
|
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||||
s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
|
s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
|
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
|
||||||
}
|
}
|
||||||
tlsEnabled = true
|
tlsEnabled = true
|
||||||
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
|
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
|
||||||
tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
|
tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
|
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
|
|
||||||
if !s.disableMetrics {
|
if !s.disableMetrics {
|
||||||
idpManager := "disabled"
|
idpManager := "disabled"
|
||||||
if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
|
if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" {
|
||||||
idpManager = s.config.IdpManagerConfig.ManagerType
|
idpManager = s.Config.IdpManagerConfig.ManagerType
|
||||||
}
|
}
|
||||||
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
|
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
|
||||||
go metricsWorker.Run(srvCtx)
|
go metricsWorker.Run(srvCtx)
|
||||||
|
|||||||
@@ -104,6 +104,20 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToSkipSyncResponse creates a minimal SyncResponse when the client already has the latest network map.
|
||||||
|
func ToSkipSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, checks []*posture.Checks, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
|
||||||
|
response := &proto.SyncResponse{
|
||||||
|
SkipNetworkMapUpdate: true,
|
||||||
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
|
}
|
||||||
|
|
||||||
|
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||||
|
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||||
|
response.NetbirdConfig = extendedConfig
|
||||||
|
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
|
||||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||||
response := &proto.SyncResponse{
|
response := &proto.SyncResponse{
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
||||||
@@ -369,7 +383,7 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi
|
|||||||
}
|
}
|
||||||
|
|
||||||
issuer := strings.TrimSpace(config.AuthIssuer)
|
issuer := strings.TrimSpace(config.AuthIssuer)
|
||||||
if issuer == "" || deviceFlowConfig != nil {
|
if issuer == "" && deviceFlowConfig != nil {
|
||||||
if d := deriveIssuerFromTokenEndpoint(deviceFlowConfig.ProviderConfig.TokenEndpoint); d != "" {
|
if d := deriveIssuerFromTokenEndpoint(deviceFlowConfig.ProviderConfig.TokenEndpoint); d != "" {
|
||||||
issuer = d
|
issuer = d
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -55,13 +54,10 @@ const (
|
|||||||
type Server struct {
|
type Server struct {
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
wgKey wgtypes.Key
|
|
||||||
proto.UnimplementedManagementServiceServer
|
proto.UnimplementedManagementServiceServer
|
||||||
peersUpdateManager network_map.PeersUpdateManager
|
|
||||||
config *nbconfig.Config
|
config *nbconfig.Config
|
||||||
secretsManager SecretsManager
|
secretsManager SecretsManager
|
||||||
appMetrics telemetry.AppMetrics
|
appMetrics telemetry.AppMetrics
|
||||||
ephemeralManager ephemeral.Manager
|
|
||||||
peerLocks sync.Map
|
peerLocks sync.Map
|
||||||
authManager auth.Manager
|
authManager auth.Manager
|
||||||
|
|
||||||
@@ -82,23 +78,16 @@ func NewServer(
|
|||||||
config *nbconfig.Config,
|
config *nbconfig.Config,
|
||||||
accountManager account.Manager,
|
accountManager account.Manager,
|
||||||
settingsManager settings.Manager,
|
settingsManager settings.Manager,
|
||||||
peersUpdateManager network_map.PeersUpdateManager,
|
|
||||||
secretsManager SecretsManager,
|
secretsManager SecretsManager,
|
||||||
appMetrics telemetry.AppMetrics,
|
appMetrics telemetry.AppMetrics,
|
||||||
ephemeralManager ephemeral.Manager,
|
|
||||||
authManager auth.Manager,
|
authManager auth.Manager,
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||||
networkMapController network_map.Controller,
|
networkMapController network_map.Controller,
|
||||||
) (*Server, error) {
|
) (*Server, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if appMetrics != nil {
|
if appMetrics != nil {
|
||||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||||
return int64(peersUpdateManager.CountStreams())
|
return int64(networkMapController.CountStreams())
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -120,16 +109,12 @@ func NewServer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
wgKey: key,
|
|
||||||
// peerKey -> event channel
|
|
||||||
peersUpdateManager: peersUpdateManager,
|
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
settingsManager: settingsManager,
|
settingsManager: settingsManager,
|
||||||
config: config,
|
config: config,
|
||||||
secretsManager: secretsManager,
|
secretsManager: secretsManager,
|
||||||
authManager: authManager,
|
authManager: authManager,
|
||||||
appMetrics: appMetrics,
|
appMetrics: appMetrics,
|
||||||
ephemeralManager: ephemeralManager,
|
|
||||||
logBlockedPeers: logBlockedPeers,
|
logBlockedPeers: logBlockedPeers,
|
||||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||||
integratedPeerValidator: integratedPeerValidator,
|
integratedPeerValidator: integratedPeerValidator,
|
||||||
@@ -149,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
|
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
|
|
||||||
}()
|
|
||||||
|
|
||||||
// todo introduce something more meaningful with the key expiration/rotation
|
// todo introduce something more meaningful with the key expiration/rotation
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
@@ -163,8 +144,14 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
|
|||||||
nanos := int32(now.Nanosecond())
|
nanos := int32(now.Nanosecond())
|
||||||
expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos}
|
expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos}
|
||||||
|
|
||||||
|
key, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err)
|
||||||
|
return nil, errors.New("failed to get wireguard key")
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.ServerKeyResponse{
|
return &proto.ServerKeyResponse{
|
||||||
Key: s.wgKey.PublicKey().String(),
|
Key: key.PublicKey().String(),
|
||||||
ExpiresAt: expiresAt,
|
ExpiresAt: expiresAt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -203,7 +190,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||||
}
|
}
|
||||||
if s.logBlockedPeers {
|
if s.logBlockedPeers {
|
||||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||||
}
|
}
|
||||||
if s.blockPeersWithSameConfig {
|
if s.blockPeersWithSameConfig {
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -231,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
|
||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||||
|
|
||||||
@@ -244,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
||||||
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||||
|
|
||||||
@@ -255,7 +239,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
metahash := metaHash(peerMeta, realIP.String())
|
metahash := metaHash(peerMeta, realIP.String())
|
||||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||||
|
|
||||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncReq.GetNetworkMapSerial())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
s.syncSem.Add(-1)
|
s.syncSem.Add(-1)
|
||||||
@@ -269,9 +253,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
|
updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
|
||||||
|
if err != nil {
|
||||||
s.ephemeralManager.OnPeerConnected(ctx, peer)
|
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||||
|
s.syncSem.Add(-1)
|
||||||
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
|
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
|
||||||
|
|
||||||
@@ -323,13 +311,19 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
|||||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||||
// then sends the encrypted message to the connected peer via the sync server.
|
// then sends the encrypted message to the connected peer via the sync server.
|
||||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
key, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||||
WgPubKey: s.wgKey.PublicKey().String(),
|
WgPubKey: key.PublicKey().String(),
|
||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -348,11 +342,10 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||||
}
|
}
|
||||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
|
||||||
s.secretsManager.CancelRefresh(peer.ID)
|
s.secretsManager.CancelRefresh(peer.ID)
|
||||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||||
@@ -504,7 +497,12 @@ func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage,
|
|||||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
|
key, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = encryption.DecryptMessage(peerKey, key, req.Body, parsed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
|
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||||
}
|
}
|
||||||
@@ -520,7 +518,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
reqStart := time.Now()
|
reqStart := time.Now()
|
||||||
realIP := getRealIP(ctx)
|
realIP := getRealIP(ctx)
|
||||||
sRealIP := realIP.String()
|
sRealIP := realIP.String()
|
||||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
|
||||||
|
|
||||||
loginReq := &proto.LoginRequest{}
|
loginReq := &proto.LoginRequest{}
|
||||||
peerKey, err := s.parseRequest(ctx, req, loginReq)
|
peerKey, err := s.parseRequest(ctx, req, loginReq)
|
||||||
@@ -532,7 +529,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
metahashed := metaHash(peerMeta, sRealIP)
|
metahashed := metaHash(peerMeta, sRealIP)
|
||||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||||
if s.logBlockedPeers {
|
if s.logBlockedPeers {
|
||||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||||
}
|
}
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
|
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
|
||||||
@@ -556,16 +553,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
//nolint
|
//nolint
|
||||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if s.appMetrics != nil {
|
if s.appMetrics != nil {
|
||||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||||
}
|
}
|
||||||
took := time.Since(reqStart)
|
|
||||||
if took > 7*time.Second {
|
|
||||||
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if loginReq.GetMeta() == nil {
|
if loginReq.GetMeta() == nil {
|
||||||
@@ -599,30 +592,26 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
return nil, mapError(ctx, err)
|
return nil, mapError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
|
|
||||||
|
|
||||||
// if the login request contains setup key then it is a registration request
|
|
||||||
if loginReq.GetSetupKey() != "" {
|
|
||||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
|
||||||
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
|
|
||||||
}
|
|
||||||
|
|
||||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
||||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
|
key, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||||
|
}
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
||||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &proto.EncryptedMessage{
|
return &proto.EncryptedMessage{
|
||||||
WgPubKey: s.wgKey.PublicKey().String(),
|
WgPubKey: key.PublicKey().String(),
|
||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -713,19 +702,27 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
|||||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
var plainResp *proto.SyncResponse
|
||||||
|
if networkMap == nil {
|
||||||
|
plainResp = ToSkipSyncResponse(ctx, s.config, peer, turnToken, relayToken, postureChecks, settings.Extra, peerGroups)
|
||||||
|
} else {
|
||||||
|
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||||
|
}
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
key, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(codes.Internal, "failed getting server key")
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "error handling request")
|
return status.Errorf(codes.Internal, "error handling request")
|
||||||
}
|
}
|
||||||
|
|
||||||
sendStart := time.Now()
|
|
||||||
err = srv.Send(&proto.EncryptedMessage{
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
WgPubKey: s.wgKey.PublicKey().String(),
|
WgPubKey: key.PublicKey().String(),
|
||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
})
|
})
|
||||||
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
||||||
@@ -740,10 +737,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
|||||||
// which will be used by our clients to Login
|
// which will be used by our clients to Login
|
||||||
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
|
|
||||||
}()
|
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -752,7 +745,12 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
|
|||||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
|
key, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to get server key")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||||
log.WithContext(ctx).Warn(errMSG)
|
log.WithContext(ctx).Warn(errMSG)
|
||||||
@@ -782,13 +780,13 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
|
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &proto.EncryptedMessage{
|
return &proto.EncryptedMessage{
|
||||||
WgPubKey: s.wgKey.PublicKey().String(),
|
WgPubKey: key.PublicKey().String(),
|
||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -798,10 +796,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
|
|||||||
// which will be used by our clients to Login
|
// which will be used by our clients to Login
|
||||||
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||||
start := time.Now()
|
|
||||||
defer func() {
|
|
||||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
|
|
||||||
}()
|
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -810,7 +804,12 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
|
|||||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
|
key, err := s.secretsManager.GetWGKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to get server key")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||||
log.WithContext(ctx).Warn(errMSG)
|
log.WithContext(ctx).Warn(errMSG)
|
||||||
@@ -838,13 +837,13 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
|
|||||||
|
|
||||||
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
|
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &proto.EncryptedMessage{
|
return &proto.EncryptedMessage{
|
||||||
WgPubKey: s.wgKey.PublicKey().String(),
|
WgPubKey: key.PublicKey().String(),
|
||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,15 +73,17 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
mgmtServer := &Server{
|
mgmtServer := &Server{
|
||||||
wgKey: testingServerKey,
|
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
|
||||||
config: &config.Config{
|
config: &config.Config{
|
||||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
||||||
|
key, err := mgmtServer.secretsManager.GetWGKey()
|
||||||
|
require.NoError(t, err, "should be able to get server key")
|
||||||
|
|
||||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
|
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
|
||||||
require.NoError(t, err, "should be able to encrypt message")
|
require.NoError(t, err, "should be able to encrypt message")
|
||||||
|
|
||||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
||||||
@@ -95,7 +97,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
if testCase.expectedComparisonFunc != nil {
|
if testCase.expectedComparisonFunc != nil {
|
||||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
||||||
|
|
||||||
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||||
require.NoError(t, err, "should be able to decrypt")
|
require.NoError(t, err, "should be able to decrypt")
|
||||||
|
|
||||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
@@ -29,6 +30,7 @@ type SecretsManager interface {
|
|||||||
GenerateRelayToken() (*Token, error)
|
GenerateRelayToken() (*Token, error)
|
||||||
SetupRefresh(ctx context.Context, accountID, peerKey string)
|
SetupRefresh(ctx context.Context, accountID, peerKey string)
|
||||||
CancelRefresh(peerKey string)
|
CancelRefresh(peerKey string)
|
||||||
|
GetWGKey() (wgtypes.Key, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
||||||
@@ -43,11 +45,17 @@ type TimeBasedAuthSecretsManager struct {
|
|||||||
groupsManager groups.Manager
|
groupsManager groups.Manager
|
||||||
turnCancelMap map[string]chan struct{}
|
turnCancelMap map[string]chan struct{}
|
||||||
relayCancelMap map[string]chan struct{}
|
relayCancelMap map[string]chan struct{}
|
||||||
|
wgKey wgtypes.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
type Token auth.Token
|
type Token auth.Token
|
||||||
|
|
||||||
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
mgr := &TimeBasedAuthSecretsManager{
|
mgr := &TimeBasedAuthSecretsManager{
|
||||||
updateManager: updateManager,
|
updateManager: updateManager,
|
||||||
turnCfg: turnCfg,
|
turnCfg: turnCfg,
|
||||||
@@ -56,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
|
|||||||
relayCancelMap: make(map[string]chan struct{}),
|
relayCancelMap: make(map[string]chan struct{}),
|
||||||
settingsManager: settingsManager,
|
settingsManager: settingsManager,
|
||||||
groupsManager: groupsManager,
|
groupsManager: groupsManager,
|
||||||
|
wgKey: key,
|
||||||
}
|
}
|
||||||
|
|
||||||
if turnCfg != nil {
|
if turnCfg != nil {
|
||||||
@@ -81,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return mgr
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWGKey returns WireGuard private key used to generate peer keys
|
||||||
|
func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) {
|
||||||
|
return m.wgKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateTurnToken generates new time-based secret credentials for TURN
|
// GenerateTurnToken generates new time-based secret credentials for TURN
|
||||||
@@ -153,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI
|
|||||||
relayCancel := make(chan struct{}, 1)
|
relayCancel := make(chan struct{}, 1)
|
||||||
m.relayCancelMap[peerID] = relayCancel
|
m.relayCancelMap[peerID] = relayCancel
|
||||||
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
|
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
|
||||||
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
|
log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-cancel:
|
case <-cancel:
|
||||||
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
|
log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
|
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
|
||||||
@@ -179,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-cancel:
|
case <-cancel:
|
||||||
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
|
log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
m.pushNewRelayTokens(ctx, accountID, peerID)
|
m.pushNewRelayTokens(ctx, accountID, peerID)
|
||||||
|
|||||||
@@ -46,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
|||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*config.Host{TurnTestHost},
|
Turns: []*config.Host{TurnTestHost},
|
||||||
TimeBasedCredentials: true,
|
TimeBasedCredentials: true,
|
||||||
}, rc, settingsMockManager, groupsManager)
|
}, rc, settingsMockManager, groupsManager)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
turnCredentials, err := tested.GenerateTurnToken()
|
turnCredentials, err := tested.GenerateTurnToken()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -98,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
|||||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*config.Host{TurnTestHost},
|
Turns: []*config.Host{TurnTestHost},
|
||||||
TimeBasedCredentials: true,
|
TimeBasedCredentials: true,
|
||||||
}, rc, settingsMockManager, groupsManager)
|
}, rc, settingsMockManager, groupsManager)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -201,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
|||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||||
CredentialsTTL: ttl,
|
CredentialsTTL: ttl,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Turns: []*config.Host{TurnTestHost},
|
Turns: []*config.Host{TurnTestHost},
|
||||||
TimeBasedCredentials: true,
|
TimeBasedCredentials: true,
|
||||||
}, rc, settingsMockManager, groupsManager)
|
}, rc, settingsMockManager, groupsManager)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
||||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
if _, ok := tested.turnCancelMap[peer]; !ok {
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -77,7 +76,6 @@ type DefaultAccountManager struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
eventStore activity.Store
|
eventStore activity.Store
|
||||||
geo geolocation.Geolocation
|
geo geolocation.Geolocation
|
||||||
ephemeralManager ephemeral.Manager
|
|
||||||
|
|
||||||
requestBuffer *AccountRequestBuffer
|
requestBuffer *AccountRequestBuffer
|
||||||
|
|
||||||
@@ -238,7 +236,7 @@ func BuildManager(
|
|||||||
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter)
|
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter)
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
|
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting cache store: %s", err)
|
return nil, fmt.Errorf("getting cache store: %s", err)
|
||||||
}
|
}
|
||||||
@@ -263,10 +261,6 @@ func BuildManager(
|
|||||||
return am, nil
|
return am, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
|
|
||||||
am.ephemeralManager = em
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
|
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
|
||||||
return am.externalCacheManager
|
return am.externalCacheManager
|
||||||
}
|
}
|
||||||
@@ -301,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
|
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if oldSettings.Extra != nil && newSettings.Extra != nil &&
|
||||||
|
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
|
||||||
|
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to approve pending peers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if approvedCount > 0 {
|
||||||
|
log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID)
|
||||||
|
updateAccountPeers = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
if oldSettings.NetworkRange != newSettings.NetworkRange {
|
||||||
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -325,6 +332,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newSettings.Extra.IntegratedValidatorGroups = oldSettings.Extra.IntegratedValidatorGroups
|
||||||
|
newSettings.Extra.IntegratedValidator = oldSettings.Extra.IntegratedValidator
|
||||||
|
|
||||||
if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
|
if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -375,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return newSettings, nil
|
return newSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||||
halfYearLimit := 180 * 24 * time.Hour
|
halfYearLimit := 180 * 24 * time.Hour
|
||||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||||
@@ -389,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
|
|||||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
|
||||||
for _, peer := range peers {
|
|
||||||
peersMap[peer.ID] = peer
|
|
||||||
}
|
|
||||||
|
|
||||||
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||||
@@ -790,6 +790,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
|||||||
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
||||||
accountIDString := fmt.Sprintf("%v", accountID)
|
accountIDString := fmt.Sprintf("%v", accountID)
|
||||||
|
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:staticcheck
|
||||||
|
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||||
|
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -1610,8 +1617,8 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
|||||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, NetworkMapSerial: clientSerial}, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2073,7 +2080,10 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
am.networkMapController.OnPeerUpdated(peer.AccountID, peer)
|
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -108,7 +107,7 @@ type Manager interface {
|
|||||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
@@ -124,5 +123,4 @@ type Manager interface {
|
|||||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
||||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
SetEphemeralManager(em ephemeral.Manager)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
@@ -2056,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) {
|
||||||
|
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
accountID := account.Id
|
||||||
|
userID := account.Users[account.CreatedBy].Id
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
newSettings := account.Settings.Copy()
|
||||||
|
newSettings.Extra = &types.ExtraSettings{
|
||||||
|
PeerApprovalEnabled: true,
|
||||||
|
}
|
||||||
|
_, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
peer1.Status.RequiresApproval = true
|
||||||
|
peer2.Status.RequiresApproval = true
|
||||||
|
peer3.Status.RequiresApproval = false
|
||||||
|
|
||||||
|
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1))
|
||||||
|
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2))
|
||||||
|
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3))
|
||||||
|
|
||||||
|
newSettings = account.Settings.Copy()
|
||||||
|
newSettings.Extra = &types.ExtraSettings{
|
||||||
|
PeerApprovalEnabled: false,
|
||||||
|
}
|
||||||
|
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, peer := range accountPeers {
|
||||||
|
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_GetExpiredPeers(t *testing.T) {
|
func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
@@ -2959,8 +2998,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||||
manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -3105,7 +3144,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, 0)
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3371,7 +3410,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("memory cache", func(t *testing.T) {
|
t.Run("memory cache", func(t *testing.T) {
|
||||||
t.Run("should always return true", func(t *testing.T) {
|
t.Run("should always return true", func(t *testing.T) {
|
||||||
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
|
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cold, err := manager.isCacheCold(context.Background(), cacheStore)
|
cold, err := manager.isCacheCold(context.Background(), cacheStore)
|
||||||
@@ -3386,7 +3425,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
|||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
||||||
|
|
||||||
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
|
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("should return true when no account exists", func(t *testing.T) {
|
t.Run("should return true when no account exists", func(t *testing.T) {
|
||||||
|
|||||||
@@ -179,6 +179,7 @@ const (
|
|||||||
PeerIPUpdated Activity = 88
|
PeerIPUpdated Activity = 88
|
||||||
UserApproved Activity = 89
|
UserApproved Activity = 89
|
||||||
UserRejected Activity = 90
|
UserRejected Activity = 90
|
||||||
|
UserCreated Activity = 91
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
@@ -288,6 +289,7 @@ var activityMap = map[Activity]Code{
|
|||||||
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
||||||
UserApproved: {"User approved", "user.approve"},
|
UserApproved: {"User approved", "user.approve"},
|
||||||
UserRejected: {"User rejected", "user.reject"},
|
UserRejected: {"User rejected", "user.reject"},
|
||||||
|
UserCreated: {"User created", "user.create"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
1
management/server/cache/idp.go
vendored
1
management/server/cache/idp.go
vendored
@@ -18,6 +18,7 @@ const (
|
|||||||
DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days
|
DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days
|
||||||
DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days
|
DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days
|
||||||
DefaultIDPCacheCleanupInterval = 30 * time.Minute
|
DefaultIDPCacheCleanupInterval = 30 * time.Minute
|
||||||
|
DefaultIDPCacheOpenConn = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects.
|
// UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects.
|
||||||
|
|||||||
2
management/server/cache/idp_test.go
vendored
2
management/server/cache/idp_test.go
vendored
@@ -33,7 +33,7 @@ func TestNewIDPCacheManagers(t *testing.T) {
|
|||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
||||||
}
|
}
|
||||||
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval)
|
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval, cache.DefaultIDPCacheOpenConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("couldn't create cache store: %s", err)
|
t.Fatalf("couldn't create cache store: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
16
management/server/cache/store.go
vendored
16
management/server/cache/store.go
vendored
@@ -3,6 +3,7 @@ package cache
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -20,24 +21,27 @@ const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS"
|
|||||||
|
|
||||||
// NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar
|
// NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar
|
||||||
// to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store.
|
// to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store.
|
||||||
func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration) (store.StoreInterface, error) {
|
func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) {
|
||||||
redisAddr := os.Getenv(RedisStoreEnvVar)
|
redisAddr := os.Getenv(RedisStoreEnvVar)
|
||||||
if redisAddr != "" {
|
if redisAddr != "" {
|
||||||
return getRedisStore(ctx, redisAddr)
|
return getRedisStore(ctx, redisAddr, maxConn)
|
||||||
}
|
}
|
||||||
goc := gocache.New(maxTimeout, cleanupInterval)
|
goc := gocache.New(maxTimeout, cleanupInterval)
|
||||||
return gocache_store.NewGoCache(goc), nil
|
return gocache_store.NewGoCache(goc), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRedisStore(ctx context.Context, redisEnvAddr string) (store.StoreInterface, error) {
|
func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) {
|
||||||
options, err := redis.ParseURL(redisEnvAddr)
|
options, err := redis.ParseURL(redisEnvAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing redis cache url: %s", err)
|
return nil, fmt.Errorf("parsing redis cache url: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
options.MaxIdleConns = 6
|
options.MaxIdleConns = int(math.Ceil(float64(maxConn) * 0.5)) // 50% of max conns
|
||||||
options.MinIdleConns = 3
|
options.MinIdleConns = int(math.Ceil(float64(maxConn) * 0.1)) // 10% of max conns
|
||||||
options.MaxActiveConns = 100
|
options.MaxActiveConns = maxConn
|
||||||
|
options.ConnMaxIdleTime = 30 * time.Minute
|
||||||
|
options.ConnMaxLifetime = 0
|
||||||
|
options.PoolTimeout = 10 * time.Second
|
||||||
redisClient := redis.NewClient(options)
|
redisClient := redis.NewClient(options)
|
||||||
subCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
subCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
6
management/server/cache/store_test.go
vendored
6
management/server/cache/store_test.go
vendored
@@ -15,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMemoryStore(t *testing.T) {
|
func TestMemoryStore(t *testing.T) {
|
||||||
memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
|
memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("couldn't create memory store: %s", err)
|
t.Fatalf("couldn't create memory store: %s", err)
|
||||||
}
|
}
|
||||||
@@ -42,7 +42,7 @@ func TestMemoryStore(t *testing.T) {
|
|||||||
|
|
||||||
func TestRedisStoreConnectionFailure(t *testing.T) {
|
func TestRedisStoreConnectionFailure(t *testing.T) {
|
||||||
t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379")
|
t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379")
|
||||||
_, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond)
|
_, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond, 100)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("getting redis cache store should return error")
|
t.Fatal("getting redis cache store should return error")
|
||||||
}
|
}
|
||||||
@@ -65,7 +65,7 @@ func TestRedisStoreConnectionSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
||||||
redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
|
redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("couldn't create redis store: %s", err)
|
t.Fatalf("couldn't create redis store: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
@@ -223,7 +225,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), &config.Config{})
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||||
|
|
||||||
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
|
||||||
|
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||||
@@ -39,7 +40,6 @@ import (
|
|||||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||||
nbpeers "github.com/netbirdio/netbird/management/server/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -105,6 +105,7 @@ func NewAPIHandler(
|
|||||||
accountManager.SyncUserJWTGroups,
|
accountManager.SyncUserJWTGroups,
|
||||||
accountManager.GetUserFromUserAuth,
|
accountManager.GetUserFromUserAuth,
|
||||||
rateLimitingConfig,
|
rateLimitingConfig,
|
||||||
|
appMetrics.GetMeter(),
|
||||||
)
|
)
|
||||||
|
|
||||||
corsMiddleware := cors.AllowAll()
|
corsMiddleware := cors.AllowAll()
|
||||||
|
|||||||
@@ -48,6 +48,29 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||||
|
|
||||||
|
// Check if filtering by name
|
||||||
|
groupName := r.URL.Query().Get("name")
|
||||||
|
if groupName != "" {
|
||||||
|
// Get single group by name
|
||||||
|
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return as array with single element to maintain API consistency
|
||||||
|
groupsResponse := []*api.Group{toGroupResponse(accountPeers, group)}
|
||||||
|
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all groups
|
||||||
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
|
|||||||
@@ -60,12 +60,23 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
|
|||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
},
|
},
|
||||||
|
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
|
||||||
|
groups := []*types.Group{
|
||||||
|
{ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT},
|
||||||
|
{ID: "id-existed", Name: "Existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI},
|
||||||
|
{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI},
|
||||||
|
}
|
||||||
|
|
||||||
|
groups = append(groups, initGroups...)
|
||||||
|
|
||||||
|
return groups, nil
|
||||||
|
},
|
||||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
|
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
|
||||||
if groupName == "All" {
|
if groupName == "All" {
|
||||||
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
|
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unknown group name")
|
return nil, status.Errorf(status.NotFound, "unknown group name")
|
||||||
},
|
},
|
||||||
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||||
return maps.Values(TestPeers), nil
|
return maps.Values(TestPeers), nil
|
||||||
@@ -287,6 +298,84 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetAllGroups(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatus int
|
||||||
|
expectedBody bool
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
expectedCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get All Groups",
|
||||||
|
expectedBody: true,
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/groups",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedCount: 3, // id-jwt-group, id-existed, id-all
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get Group By Name - Existing",
|
||||||
|
expectedBody: true,
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/groups?name=All",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get Group By Name - Not Found",
|
||||||
|
expectedBody: false,
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/groups?name=NonExistent",
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
p := initGroupTestData()
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
|
UserId: "test_user",
|
||||||
|
Domain: "hotmail.com",
|
||||||
|
AccountId: "test_id",
|
||||||
|
})
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if status := recorder.Code; status != tc.expectedStatus {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||||
|
status, tc.expectedStatus)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tc.expectedBody {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var groups []api.Group
|
||||||
|
if err = json.Unmarshal(content, &groups); err != nil {
|
||||||
|
t.Fatalf("Response is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectedCount, len(groups))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDeleteGroup(t *testing.T) {
|
func TestDeleteGroup(t *testing.T) {
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user