mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-19 14:29:55 +00:00
Compare commits
22 Commits
embedded-v
...
v0.73.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60a9544656 | ||
|
|
d3710d4bb2 | ||
|
|
ee360963f9 | ||
|
|
8d9580e491 | ||
|
|
5bd7c6c7ea | ||
|
|
8ae2cd0a08 | ||
|
|
e4397d4d46 | ||
|
|
6fbc90b4d3 | ||
|
|
5095e17cc5 | ||
|
|
6df0175607 | ||
|
|
3c23700e56 | ||
|
|
38ad2b67e8 | ||
|
|
01aa49433e | ||
|
|
08a2b63675 | ||
|
|
b3f9e6588a | ||
|
|
967e2d6864 | ||
|
|
e7c1d364c3 | ||
|
|
a44198fd77 | ||
|
|
b57f714350 | ||
|
|
f893abc41d | ||
|
|
60067619a1 | ||
|
|
cd777395f2 |
55
.github/workflows/release.yml
vendored
55
.github/workflows/release.yml
vendored
@@ -9,10 +9,13 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.5"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
SIGN_PIPE_VER: "v0.1.6"
|
||||
GORELEASER_VER: "v2.16.0"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
flags: ""
|
||||
SKIP_PUBLISH: "true"
|
||||
SKIP_DOCKER_PUSH: "false"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -130,8 +133,6 @@ jobs:
|
||||
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
|
||||
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
@@ -143,8 +144,27 @@ jobs:
|
||||
id: semver_parser
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Set snapshot flag
|
||||
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: |
|
||||
echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
- name: Set build vars
|
||||
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: |
|
||||
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
|
||||
echo "x-${{ github.repository }}"
|
||||
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
|
||||
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
|
||||
else
|
||||
echo "x-${{ github.repository }}"
|
||||
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
|
||||
fi
|
||||
|
||||
if [[ "x-${{ github.repository }}" != "x-netbirdio/netbird" ]]; then
|
||||
echo "SKIP_DOCKER_PUSH=true" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
@@ -161,6 +181,8 @@ jobs:
|
||||
${{ runner.os }}-go-releaser-
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
- name: run openapi generator
|
||||
run: bash shared/management/http/api/generate.sh
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
@@ -210,6 +232,8 @@ jobs:
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
|
||||
SKIP_DOCKER_PUSH: ${{ env.SKIP_DOCKER_PUSH }}
|
||||
- name: Verify RPM signatures
|
||||
run: |
|
||||
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||
@@ -332,8 +356,22 @@ jobs:
|
||||
id: semver_parser
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Set snapshot flag
|
||||
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: |
|
||||
echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
- name: Set build vars
|
||||
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: |
|
||||
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
|
||||
echo "x-${{ github.repository }}"
|
||||
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
|
||||
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
|
||||
else
|
||||
echo "x-${{ github.repository }}"
|
||||
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
@@ -393,6 +431,7 @@ jobs:
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
|
||||
- name: Verify RPM signatures
|
||||
run: |
|
||||
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||
|
||||
862
.goreleaser.yaml
862
.goreleaser.yaml
@@ -1,5 +1,7 @@
|
||||
version: 2
|
||||
|
||||
env:
|
||||
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
|
||||
- SKIP_DOCKER_PUSH={{ if index .Env "SKIP_DOCKER_PUSH" }}{{ .Env.SKIP_DOCKER_PUSH }}{{ else }}false{{ end }}
|
||||
project_name: netbird
|
||||
builds:
|
||||
- id: netbird-wasm
|
||||
@@ -74,6 +76,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -88,6 +92,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -102,6 +108,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -122,6 +130,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -136,6 +146,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -150,6 +162,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -170,6 +184,8 @@ builds:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
goarm:
|
||||
- 7
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
@@ -222,670 +238,192 @@ nfpms:
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
dockers:
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: signal/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: signal/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: signal/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile.debug
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile.debug
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: management/Dockerfile.debug
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-server
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: combined/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-proxy
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: proxy/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:latest
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:{{ .Version }}-rootless
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:rootless-latest
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/relay:latest
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/signal:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- netbirdio/signal:{{ .Version }}-arm
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/signal:latest
|
||||
image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- netbirdio/signal:{{ .Version }}-arm
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/management:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-arm
|
||||
- netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/management:latest
|
||||
image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-arm
|
||||
- netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/management:debug-latest
|
||||
image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- name_template: netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/upload:latest
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/relay:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/signal:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:debug-latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/upload:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird-server:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||
dockers_v2:
|
||||
- id: netbird
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird
|
||||
images:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm/6
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: netbird-rootless
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird
|
||||
images:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "v{{ .Version }}-rootless"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm/6
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: relay
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-relay
|
||||
images:
|
||||
- netbirdio/relay
|
||||
- ghcr.io/netbirdio/relay
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: relay/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: signal
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-signal
|
||||
images:
|
||||
- netbirdio/signal
|
||||
- ghcr.io/netbirdio/signal
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: signal/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: management
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
images:
|
||||
- netbirdio/management
|
||||
- ghcr.io/netbirdio/management
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: management/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: upload
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-upload
|
||||
images:
|
||||
- netbirdio/upload
|
||||
- ghcr.io/netbirdio/upload
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: upload-server/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: netbird-server
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-server
|
||||
images:
|
||||
- netbirdio/netbird-server
|
||||
- ghcr.io/netbirdio/netbird-server
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: combined/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
- id: netbird-proxy
|
||||
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
|
||||
ids:
|
||||
- netbird-proxy
|
||||
images:
|
||||
- netbirdio/reverse-proxy
|
||||
- ghcr.io/netbirdio/reverse-proxy
|
||||
tags:
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: proxy/Dockerfile
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/arm
|
||||
annotations:
|
||||
"org.opencontainers.image.created": "{{.Date}}"
|
||||
"org.opencontainers.image.title": "{{.ProjectName}}"
|
||||
"org.opencontainers.image.version": "{{.Version}}"
|
||||
"org.opencontainers.image.revision": "{{.FullCommit}}"
|
||||
"org.opencontainers.image.source": "{{.GitURL}}"
|
||||
"maintainer": "dev@netbird.io"
|
||||
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
skip_upload: "{{ .Env.SKIP_PUBLISH }}"
|
||||
repository:
|
||||
owner: netbirdio
|
||||
name: homebrew-tap
|
||||
@@ -902,6 +440,7 @@ brews:
|
||||
|
||||
uploads:
|
||||
- name: debian
|
||||
skip: "{{ .Env.SKIP_PUBLISH }}"
|
||||
ids:
|
||||
- netbird_deb
|
||||
mode: archive
|
||||
@@ -910,6 +449,7 @@ uploads:
|
||||
method: PUT
|
||||
|
||||
- name: yum
|
||||
skip: "{{ .Env.SKIP_PUBLISH }}"
|
||||
ids:
|
||||
- netbird_rpm
|
||||
mode: archive
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
version: 2
|
||||
|
||||
env:
|
||||
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
|
||||
project_name: netbird-ui
|
||||
builds:
|
||||
- id: netbird-ui
|
||||
@@ -101,6 +102,7 @@ nfpms:
|
||||
|
||||
uploads:
|
||||
- name: debian
|
||||
skip: "{{ .Env.SKIP_PUBLISH }}"
|
||||
ids:
|
||||
- netbird_ui_deb
|
||||
mode: archive
|
||||
@@ -109,6 +111,7 @@ uploads:
|
||||
method: PUT
|
||||
|
||||
- name: yum
|
||||
skip: "{{ .Env.SKIP_PUBLISH }}"
|
||||
ids:
|
||||
- netbird_ui_rpm
|
||||
mode: archive
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.23.3
|
||||
FROM alpine:3.24
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
@@ -21,7 +21,7 @@ ENV \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
ARG TARGETPLATFORM
|
||||
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
|
||||
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
|
||||
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.0
|
||||
FROM alpine:3.24
|
||||
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
@@ -27,7 +27,7 @@ ENV \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
ARG TARGETPLATFORM
|
||||
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
|
||||
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
|
||||
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -24,6 +23,7 @@ const (
|
||||
|
||||
// Profile represents a profile for gomobile
|
||||
type Profile struct {
|
||||
ID string
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
@@ -53,10 +53,10 @@ func (p *ProfileArray) Get(i int) *Profile {
|
||||
├── state.json ← Default profile state
|
||||
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
|
||||
└── profiles/ ← Subdirectory for non-default profiles
|
||||
├── work.json ← Work profile config
|
||||
├── work.state.json ← Work profile state
|
||||
├── personal.json ← Personal profile config
|
||||
└── personal.state.json ← Personal profile state
|
||||
├── work.json ← Legacy work profile config
|
||||
├── work.state.json ← Legacy work profile state
|
||||
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.json ← ID profile config
|
||||
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.state.json ← ID profile state
|
||||
*/
|
||||
|
||||
// ProfileManager manages profiles for Android
|
||||
@@ -99,6 +99,7 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||
var profiles []*Profile
|
||||
for _, p := range internalProfiles {
|
||||
profiles = append(profiles, &Profile{
|
||||
ID: p.ID.String(),
|
||||
Name: p.Name,
|
||||
IsActive: p.IsActive,
|
||||
})
|
||||
@@ -108,55 +109,65 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the currently active profile name
|
||||
func (pm *ProfileManager) GetActiveProfile() (string, error) {
|
||||
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
activeState, err := pm.serviceMgr.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
return nil, fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return activeState.Name, nil
|
||||
|
||||
// ActiveProfileState only stores the ID (and username), not the display
|
||||
// name. Resolve the ID to the full profile so callers get the real Name.
|
||||
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID.String(), androidUsername)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
|
||||
}
|
||||
return &Profile{ID: prof.ID.String(), Name: prof.Name, IsActive: true}, nil
|
||||
}
|
||||
|
||||
// SwitchProfile switches to a different profile
|
||||
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||
func (pm *ProfileManager) SwitchProfile(id string) error {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
ID: profilemanager.ID(id),
|
||||
Username: androidUsername,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("switched to profile: %s", profileName)
|
||||
log.Infof("switched to profile: %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddProfile creates a new profile
|
||||
func (pm *ProfileManager) AddProfile(profileName string) error {
|
||||
// Use ServiceManager (creates profile in profiles/ directory)
|
||||
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
|
||||
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("created new profile: %s", profileName)
|
||||
log.Infof("created new profile: %s", profile.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogoutProfile logs out from a profile (clears authentication)
|
||||
func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
configPath, err := pm.getProfileConfigPath(profileName)
|
||||
func (pm *ProfileManager) LogoutProfile(id string) error {
|
||||
configPath, err := pm.getProfileConfigPath(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
|
||||
return fmt.Errorf("id '%s' is not valid", id)
|
||||
}
|
||||
|
||||
// Check if profile exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("profile '%s' does not exist", profileName)
|
||||
return fmt.Errorf("profile '%s' does not exist", id)
|
||||
}
|
||||
|
||||
// Read current config using internal profilemanager
|
||||
@@ -174,53 +185,57 @@ func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||
return fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("logged out from profile: %s", profileName)
|
||||
log.Infof("logged out from profile: %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProfile deletes a profile
|
||||
func (pm *ProfileManager) RemoveProfile(profileName string) error {
|
||||
func (pm *ProfileManager) RemoveProfile(id string) error {
|
||||
// Use ServiceManager (removes profile from profiles/ directory)
|
||||
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
|
||||
if err := pm.serviceMgr.RemoveProfile(profilemanager.ID(id), androidUsername); err != nil {
|
||||
return fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("removed profile: %s", profileName)
|
||||
log.Infof("removed profile: %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProfileConfigPath returns the config file path for a profile
|
||||
// This is needed for Android-specific path handling (netbird.cfg for default profile)
|
||||
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) {
|
||||
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
|
||||
return "", fmt.Errorf("id %q is not valid", id)
|
||||
}
|
||||
|
||||
if id == profilemanager.DefaultProfileName {
|
||||
// Android uses netbird.cfg for default profile instead of default.json
|
||||
// Default profile is stored in root configDir, not in profiles/
|
||||
return filepath.Join(pm.configDir, defaultConfigFilename), nil
|
||||
}
|
||||
|
||||
// Non-default profiles are stored in profiles subdirectory
|
||||
// This matches the Java Preferences.java expectation
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, profileName+".json"), nil
|
||||
return filepath.Join(profilesDir, id+".json"), nil
|
||||
}
|
||||
|
||||
// GetConfigPath returns the config file path for a given profile
|
||||
// GetConfigPath returns the config file path for a given profile id
|
||||
// Java should call this instead of constructing paths with Preferences.configFile()
|
||||
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
|
||||
return pm.getProfileConfigPath(profileName)
|
||||
func (pm *ProfileManager) GetConfigPath(id string) (string, error) {
|
||||
return pm.getProfileConfigPath(id)
|
||||
}
|
||||
|
||||
// GetStateFilePath returns the state file path for a given profile
|
||||
// Java should call this instead of constructing paths with Preferences.stateFile()
|
||||
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) {
|
||||
if id == "" || id == profilemanager.DefaultProfileName {
|
||||
return filepath.Join(pm.configDir, "state.json"), nil
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
|
||||
return "", fmt.Errorf("id %q is not valid", id)
|
||||
}
|
||||
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, profileName+".state.json"), nil
|
||||
return filepath.Join(profilesDir, id+".state.json"), nil
|
||||
}
|
||||
|
||||
// GetActiveConfigPath returns the config file path for the currently active profile
|
||||
@@ -230,7 +245,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetConfigPath(activeProfile)
|
||||
return pm.GetConfigPath(activeProfile.ID)
|
||||
}
|
||||
|
||||
// GetActiveStateFilePath returns the state file path for the currently active profile
|
||||
@@ -240,18 +255,5 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetStateFilePath(activeProfile)
|
||||
}
|
||||
|
||||
// sanitizeProfileName removes invalid characters from profile name
|
||||
func sanitizeProfileName(name string) string {
|
||||
// Keep only alphanumeric, underscore, and hyphen
|
||||
var result strings.Builder
|
||||
for _, r := range name {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
return pm.GetStateFilePath(activeProfile.ID)
|
||||
}
|
||||
|
||||
@@ -96,17 +96,19 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||
}
|
||||
|
||||
handle := activeProf.ID.String()
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
ProfileName: &activeProf.Name,
|
||||
ProfileName: &handle,
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -170,14 +172,13 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
|
||||
return activeProf, nil
|
||||
}
|
||||
|
||||
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
|
||||
err := switchProfile(context.Background(), profileName, username)
|
||||
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error {
|
||||
resolvedID, err := switchProfile(ctx, handle, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile on daemon: %v", err)
|
||||
}
|
||||
|
||||
err = pm.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
@@ -205,11 +206,15 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
|
||||
return nil
|
||||
}
|
||||
|
||||
func switchProfile(ctx context.Context, profileName string, username string) error {
|
||||
// switchProfile asks the daemon to switch to the profile identified by
|
||||
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
|
||||
// ID so the caller can update the local active-profile state without
|
||||
// re-resolving the handle.
|
||||
func switchProfile(ctx context.Context, handle string, username string) (profilemanager.ID, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
@@ -217,15 +222,15 @@ func switchProfile(ctx context.Context, profileName string, username string) err
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profileName,
|
||||
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile failed: %v", err)
|
||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
}
|
||||
|
||||
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
|
||||
@@ -249,7 +254,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -277,7 +282,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string, profileID profilemanager.ID) error {
|
||||
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
@@ -291,7 +296,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
|
||||
jwtToken := ""
|
||||
if setupKey == "" && needsLogin {
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -306,10 +311,10 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileID profilemanager.ID) (*auth.TokenInfo, error) {
|
||||
hint := ""
|
||||
pm := profilemanager.NewProfileManager()
|
||||
profileState, err := pm.GetProfileState(profileName)
|
||||
profileState, err := pm.GetProfileState(profileID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
|
||||
@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
sm := profilemanager.ServiceManager{}
|
||||
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: "default",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -2,11 +2,16 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
@@ -14,6 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var profileListShowID bool
|
||||
|
||||
var profileCmd = &cobra.Command{
|
||||
Use: "profile",
|
||||
Short: "Manage NetBird client profiles",
|
||||
@@ -31,27 +38,40 @@ var profileListCmd = &cobra.Command{
|
||||
var profileAddCmd = &cobra.Command{
|
||||
Use: "add <profile_name>",
|
||||
Short: "Add a new profile",
|
||||
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
|
||||
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: addProfileFunc,
|
||||
}
|
||||
|
||||
var profileRenameCmd = &cobra.Command{
|
||||
Use: "rename <profile> <new_profile_name>",
|
||||
Short: "Renames an existing profile",
|
||||
Long: `Renames an existing profile (by a name, ID, or unique ID prefix). Profile name is free-form.`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: renameProfileFunc,
|
||||
}
|
||||
|
||||
var profileRemoveCmd = &cobra.Command{
|
||||
Use: "remove <profile_name>",
|
||||
Short: "Remove a profile",
|
||||
Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: removeProfileFunc,
|
||||
Use: "remove <profile>",
|
||||
Short: "Remove a profile",
|
||||
Long: `Remove a profile by name, ID, or unique ID prefix.`,
|
||||
Aliases: []string{"rm"},
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: removeProfileFunc,
|
||||
}
|
||||
|
||||
var profileSelectCmd = &cobra.Command{
|
||||
Use: "select <profile_name>",
|
||||
Use: "select <profile>",
|
||||
Short: "Select a profile",
|
||||
Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
|
||||
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: selectProfileFunc,
|
||||
}
|
||||
|
||||
func init() {
|
||||
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
|
||||
}
|
||||
|
||||
func setupCmd(cmd *cobra.Command) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
@@ -65,6 +85,7 @@ func setupCmd(cmd *cobra.Command) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
@@ -83,25 +104,33 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
|
||||
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// list profiles, add a tick if the profile is active
|
||||
cmd.Println("Found", len(profiles.Profiles), "profiles:")
|
||||
for _, profile := range profiles.Profiles {
|
||||
// use a cross to indicate the passive profiles
|
||||
activeMarker := "✗"
|
||||
if profile.IsActive {
|
||||
activeMarker = "✓"
|
||||
}
|
||||
cmd.Println(activeMarker, profile.Name)
|
||||
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
|
||||
if profileListShowID {
|
||||
fmt.Fprintln(tw, "ID\tNAME\tACTIVE")
|
||||
} else {
|
||||
fmt.Fprintln(tw, "NAME\tACTIVE")
|
||||
}
|
||||
|
||||
return nil
|
||||
for _, profile := range resp.Profiles {
|
||||
marker := ""
|
||||
if profile.IsActive {
|
||||
marker = "✓"
|
||||
}
|
||||
name := profilemanager.StripCtrlChars(profile.Name)
|
||||
id := profilemanager.ID(profile.Id)
|
||||
if profileListShowID {
|
||||
fmt.Fprintf(tw, "%s\t%s\t%s\n", id.ShortID(), name, marker)
|
||||
} else {
|
||||
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
|
||||
}
|
||||
}
|
||||
return tw.Flush()
|
||||
}
|
||||
|
||||
func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -121,21 +150,82 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profileName := args[0]
|
||||
|
||||
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("add profile request: %w", err)
|
||||
}
|
||||
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||
if dupCount > 1 {
|
||||
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, profileName)
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
|
||||
id := profilemanager.ID(resp.Id)
|
||||
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func renameProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Profile added successfully:", profileName)
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
handle := args[0]
|
||||
newProfilename := args[1]
|
||||
|
||||
resp, err := daemonClient.RenameProfile(cmd.Context(), &proto.RenameProfileRequest{
|
||||
Handle: handle,
|
||||
Username: currUser.Username,
|
||||
NewProfileName: newProfilename,
|
||||
})
|
||||
if err != nil {
|
||||
return wrapAmbiguityError(err, handle)
|
||||
}
|
||||
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, newProfilename)
|
||||
if dupCount > 1 {
|
||||
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, newProfilename)
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
|
||||
cmd.Printf("Profile renamed from %s to %s\n", profilemanager.StripCtrlChars(resp.OldProfileName), profilemanager.StripCtrlChars(newProfilename))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
|
||||
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := 0
|
||||
for _, p := range resp.Profiles {
|
||||
if p.Name == name {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
@@ -153,18 +243,17 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
handle := args[0]
|
||||
|
||||
profileName := args[0]
|
||||
|
||||
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
|
||||
ProfileName: profileName,
|
||||
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
|
||||
ProfileName: handle,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return wrapAmbiguityError(err, handle)
|
||||
}
|
||||
|
||||
cmd.Println("Profile removed successfully:", profileName)
|
||||
cmd.Printf("Profile removed: %s\n", resp.Id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -174,7 +263,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
profileManager := profilemanager.NewProfileManager()
|
||||
profileName := args[0]
|
||||
handle := args[0]
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
@@ -191,32 +280,15 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
|
||||
Username: currUser.Username,
|
||||
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list profiles: %w", err)
|
||||
return wrapAmbiguityError(err, handle)
|
||||
}
|
||||
|
||||
var profileExists bool
|
||||
|
||||
for _, profile := range profiles.Profiles {
|
||||
if profile.Name == profileName {
|
||||
profileExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !profileExists {
|
||||
return fmt.Errorf("profile %s does not exist", profileName)
|
||||
}
|
||||
|
||||
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = profileManager.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
if err := profileManager.SwitchProfile(profilemanager.ID(switchResp.Id)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -231,6 +303,30 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Profile switched successfully to:", profileName)
|
||||
id := profilemanager.ID(switchResp.Id)
|
||||
cmd.Printf("Profile switched to: %s\n", id.ShortID())
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
|
||||
// (which carry the resolver's message verbatim) into CLI-friendly text
|
||||
// that points the user at --show-id.
|
||||
func wrapAmbiguityError(err error, handle string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
st, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
switch st.Code() {
|
||||
case codes.InvalidArgument:
|
||||
msg := st.Message()
|
||||
if strings.Contains(msg, "ambiguous") {
|
||||
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
|
||||
}
|
||||
case codes.NotFound:
|
||||
return fmt.Errorf("profile %q not found", handle)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -190,6 +190,7 @@ func init() {
|
||||
// profile commands
|
||||
profileCmd.AddCommand(profileListCmd)
|
||||
profileCmd.AddCommand(profileAddCmd)
|
||||
profileCmd.AddCommand(profileRenameCmd)
|
||||
profileCmd.AddCommand(profileRemoveCmd)
|
||||
profileCmd.AddCommand(profileSelectCmd)
|
||||
|
||||
|
||||
@@ -128,13 +128,12 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
var profileSwitched bool
|
||||
// switch profile if provided
|
||||
if profileName != "" {
|
||||
err = switchProfile(cmd.Context(), profileName, username.Username)
|
||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
err = pm.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
@@ -190,7 +189,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -261,10 +260,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
}
|
||||
|
||||
// set the new config
|
||||
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
|
||||
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID.String(), username.Username)
|
||||
if _, err := client.SetConfig(ctx, req); err != nil {
|
||||
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
|
||||
log.Warnf("setConfig method is not available in the daemon")
|
||||
log.Warnf("setConfig method is not available in the daemon: %s", st.Message())
|
||||
} else {
|
||||
return fmt.Errorf("call service setConfig method: %v", err)
|
||||
}
|
||||
@@ -289,10 +288,11 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
return fmt.Errorf("setup login request: %v", err)
|
||||
}
|
||||
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
profileID := activeProf.ID.String()
|
||||
loginRequest.ProfileName = &profileID
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -329,7 +329,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{
|
||||
ProfileName: &activeProf.Name,
|
||||
ProfileName: &profileID,
|
||||
Username: &username,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
|
||||
@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
|
||||
}
|
||||
|
||||
sm := profilemanager.ServiceManager{}
|
||||
err = sm.AddProfile("test1", currUser.Username)
|
||||
created, err := sm.AddProfile("test1", currUser.Username)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add profile: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "test1",
|
||||
ID: created.ID,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -41,7 +41,6 @@ type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn udpmux.FilterFn
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
|
||||
@@ -61,12 +60,11 @@ type ICEBind struct {
|
||||
ipv6Conn *net.UDPConn
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
@@ -265,7 +263,6 @@ func (s *ICEBind) createOrUpdateMux() {
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: muxConn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
MTU: s.mtu,
|
||||
},
|
||||
|
||||
@@ -289,7 +289,7 @@ func setupICEBind(t *testing.T) *ICEBind {
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/10"),
|
||||
}
|
||||
return NewICEBind(transportNet, nil, address, 1280)
|
||||
return NewICEBind(transportNet, address, 1280)
|
||||
}
|
||||
|
||||
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
@@ -41,10 +44,13 @@ type PacketCapture interface {
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
|
||||
filter PacketFilter
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
filter PacketFilter
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
// panicHandler is invoked after a panic in the underlying device is
|
||||
// recovered in Read or Write.
|
||||
panicHandler atomic.Pointer[func()]
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// newDeviceFilter constructor function
|
||||
@@ -70,7 +76,7 @@ func (d *FilteredDevice) Close() error {
|
||||
|
||||
// Read wraps read method with filtering feature
|
||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||
if n, err = d.deviceRead(bufs, sizes, offset); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -112,7 +118,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if filter == nil {
|
||||
return d.Device.Write(bufs, offset)
|
||||
return d.deviceWrite(bufs, offset)
|
||||
}
|
||||
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
@@ -125,9 +131,44 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
n, err := d.Device.Write(filteredBufs, offset)
|
||||
n += dropped
|
||||
return n, err
|
||||
n, err := d.deviceWrite(filteredBufs, offset)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n + dropped, nil
|
||||
}
|
||||
|
||||
// deviceRead calls the underlying device Read, recovering from panics in the
|
||||
// wintun read path and converting them into errors.
|
||||
func (d *FilteredDevice) deviceRead(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
defer d.recoverFromPanic("read", &n, &err)
|
||||
return d.Device.Read(bufs, sizes, offset)
|
||||
}
|
||||
|
||||
// deviceWrite calls the underlying device Write, recovering from panics in the
|
||||
// wintun write path and converting them into errors.
|
||||
func (d *FilteredDevice) deviceWrite(bufs [][]byte, offset int) (n int, err error) {
|
||||
defer d.recoverFromPanic("write", &n, &err)
|
||||
return d.Device.Write(bufs, offset)
|
||||
}
|
||||
|
||||
// recoverFromPanic converts a panic in the underlying device into a regular
|
||||
// error and invokes the registered panic handler. The wintun read path is
|
||||
// known to panic on zero-length packets that third-party filter drivers can
|
||||
// place in the ring.
|
||||
func (d *FilteredDevice) recoverFromPanic(op string, n *int, err *error) {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("recovered panic in tun device %s: %v\n%s", op, r, debug.Stack())
|
||||
*n = 0
|
||||
*err = fmt.Errorf("tun device %s panic: %v", op, r)
|
||||
|
||||
if handler := d.panicHandler.Load(); handler != nil {
|
||||
(*handler)()
|
||||
}
|
||||
}
|
||||
|
||||
// SetFilter sets packet filter to device
|
||||
@@ -137,6 +178,17 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
|
||||
d.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetPanicHandler registers a handler invoked after a recovered panic in Read
|
||||
// or Write. The device is unusable after such a panic; the handler should
|
||||
// trigger recreation of the interface. Pass nil to remove.
|
||||
func (d *FilteredDevice) SetPanicHandler(handler func()) {
|
||||
if handler == nil {
|
||||
d.panicHandler.Store(nil)
|
||||
return
|
||||
}
|
||||
d.panicHandler.Store(&handler)
|
||||
}
|
||||
|
||||
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
|
||||
// Uses atomic store so the hot path (Read/Write) is a single pointer load
|
||||
// with no locking overhead when capture is off.
|
||||
|
||||
@@ -221,3 +221,60 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeviceWrapperReadPanic(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tun := mocks.NewMockDevice(ctrl)
|
||||
tun.EXPECT().Read(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
// Reproduce the wintun zero-length packet panic (index out of range).
|
||||
packet := make([]byte, 0)
|
||||
return int(packet[0]), nil
|
||||
})
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
|
||||
handlerCalled := false
|
||||
wrapped.SetPanicHandler(func() { handlerCalled = true })
|
||||
|
||||
n, err := wrapped.Read([][]byte{{}}, []int{0}, 0)
|
||||
if err == nil {
|
||||
t.Errorf("expected error from recovered panic, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("expected n=0, got %d", n)
|
||||
}
|
||||
if !handlerCalled {
|
||||
t.Errorf("expected panic handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceWrapperWritePanic(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tun := mocks.NewMockDevice(ctrl)
|
||||
tun.EXPECT().Write(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(bufs [][]byte, offset int) (int, error) {
|
||||
packet := make([]byte, 0)
|
||||
return int(packet[0]), nil
|
||||
})
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
|
||||
handlerCalled := false
|
||||
wrapped.SetPanicHandler(func() { handlerCalled = true })
|
||||
|
||||
n, err := wrapped.Write([][]byte{{0x45, 0x00}}, 0)
|
||||
if err == nil {
|
||||
t.Errorf("expected error from recovered panic, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("expected n=0, got %d", n)
|
||||
}
|
||||
if !handlerCalled {
|
||||
t.Errorf("expected panic handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,8 +32,6 @@ type TunKernelDevice struct {
|
||||
link *wgLink
|
||||
udpMuxConn net.PacketConn
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
|
||||
filterFn udpmux.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
|
||||
@@ -104,7 +102,6 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
bindParams := udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
MTU: t.mtu,
|
||||
}
|
||||
|
||||
@@ -63,7 +63,6 @@ type WGIFaceOpts struct {
|
||||
MTU uint16
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn udpmux.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
var tun WGTunDevice
|
||||
if netstack.IsEnabled() {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
wgIFace := &WGIface{
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
@@ -30,7 +30,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
|
||||
return &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
|
||||
userspaceBind: true,
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -22,10 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// FilterFn is a function that filters out candidates based on the address.
|
||||
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
|
||||
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
||||
|
||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
||||
// It then passes packets to the UDPMux that does the actual connection muxing.
|
||||
type UniversalUDPMuxDefault struct {
|
||||
@@ -43,7 +37,6 @@ type UniversalUDPMuxParams struct {
|
||||
UDPConn net.PacketConn
|
||||
XORMappedAddrCacheTTL time.Duration
|
||||
Net transport.Net
|
||||
FilterFn FilterFn
|
||||
WGAddress wgaddr.Address
|
||||
MTU uint16
|
||||
}
|
||||
@@ -68,7 +61,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
filterFn: params.FilterFn,
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
@@ -115,15 +107,12 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides WriteTo to drop packets destined for the overlay subnet.
|
||||
type UDPConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
filterFn FilterFn
|
||||
// TODO: reset cache on route changes
|
||||
addrCache sync.Map
|
||||
address wgaddr.Address
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
// GetPacketConn returns the underlying PacketConn
|
||||
@@ -132,65 +121,16 @@ func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||
}
|
||||
|
||||
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if u.filterFn == nil {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
if isRouted, found := u.addrCache.Load(addr.String()); found {
|
||||
return u.handleCachedAddress(isRouted.(bool), b, addr)
|
||||
}
|
||||
|
||||
return u.handleUncachedAddress(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
if isRouted {
|
||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
if err := u.performFilterCheck(addr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
host, err := getHostFromAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse address %s: %v", addr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a) {
|
||||
dst := udpAddr.AddrPort().Addr().Unmap()
|
||||
if (u.address.Network.IsValid() && u.address.Network.Contains(dst)) || (u.address.IPv6Net.IsValid() && u.address.IPv6Net.Contains(dst)) {
|
||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return 0, fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||
} else {
|
||||
u.addrCache.Store(addr.String(), isRouted)
|
||||
if isRouted {
|
||||
// Extra log, as the error only shows up with ICE logging enabled
|
||||
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHostFromAddr(addr net.Addr) (string, error) {
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
return host, err
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// GetSharedConn returns the shared udp conn
|
||||
@@ -225,6 +165,13 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
|
||||
return nil
|
||||
}
|
||||
|
||||
src := udpAddr.AddrPort().Addr().Unmap()
|
||||
wg := m.params.WGAddress
|
||||
if (wg.Network.IsValid() && wg.Network.Contains(src)) || (wg.IPv6Net.IsValid() && wg.IPv6Net.Contains(src)) {
|
||||
log.Debugf("dropping STUN message from overlay source %s", udpAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.isXORMappedResponse(msg, udpAddr.String()) {
|
||||
err := m.handleXORMappedResponse(udpAddr, msg)
|
||||
if err != nil {
|
||||
|
||||
@@ -66,7 +66,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
|
||||
@@ -22,7 +22,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
|
||||
@@ -118,6 +118,8 @@ func (c *ConnectClient) RunOniOS(
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
stateFilePath string,
|
||||
cacheDir string,
|
||||
logFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
debug.SetGCPercent(5)
|
||||
@@ -127,8 +129,9 @@ func (c *ConnectClient) RunOniOS(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
return c.run(mobileDependency, nil, logFilePath)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
|
||||
@@ -250,6 +250,7 @@ type BundleGenerator struct {
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logPath string
|
||||
tempDir string
|
||||
statePath string
|
||||
cpuProfile []byte
|
||||
capturePath string
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
@@ -276,6 +277,7 @@ type GeneratorDependencies struct {
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogPath string
|
||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||
StatePath string // Path to the state file. If empty, the ServiceManager default path is used.
|
||||
CPUProfile []byte
|
||||
CapturePath string
|
||||
RefreshStatus func()
|
||||
@@ -299,6 +301,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
syncResponse: deps.SyncResponse,
|
||||
logPath: deps.LogPath,
|
||||
tempDir: deps.TempDir,
|
||||
statePath: deps.StatePath,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
capturePath: deps.CapturePath,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
@@ -850,8 +853,11 @@ func (g *BundleGenerator) maskSecrets() {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStateFile() error {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path := sm.GetStatePath()
|
||||
path := g.statePath
|
||||
if path == "" {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path = sm.GetStatePath()
|
||||
}
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
36
client/internal/debug/debug_ios.go
Normal file
36
client/internal/debug/debug_ios.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// swiftLogFile is the Swift app log written by the iOS app into the same log
|
||||
// directory as the Go client log, so it can be collected into the bundle.
|
||||
const swiftLogFile = "swift-log.log"
|
||||
|
||||
// addPlatformLog collects logs for the iOS debug bundle. iOS has no logcat or
|
||||
// systemd journal, so we rely on file-based logs. addLogfile handles the Go
|
||||
// client log (logPath) with rotation, the stderr/stdout companions and
|
||||
// anonymization. The iOS app writes its own Swift log into the same directory,
|
||||
// so we add it alongside the Go log.
|
||||
func (g *BundleGenerator) addPlatformLog() error {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if g.logPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
swiftLogPath := filepath.Join(filepath.Dir(g.logPath), swiftLogFile)
|
||||
if err := g.addSingleLogfile(swiftLogPath, swiftLogFile); err != nil {
|
||||
// The Swift log is best-effort: the app may not have written it yet.
|
||||
log.Warnf("failed to add %s to debug bundle: %v", swiftLogFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build !android && !ios
|
||||
|
||||
package debug
|
||||
|
||||
|
||||
@@ -843,6 +843,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
"Name": "non-config: profile name is not needed for debug purposes",
|
||||
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
|
||||
}
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/syncstore"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
@@ -240,7 +239,7 @@ type Engine struct {
|
||||
syncStore syncstore.Store
|
||||
syncStoreDir string
|
||||
|
||||
flowManager nftypes.FlowManager
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updater.Manager
|
||||
@@ -531,6 +530,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
if filteredDevice := e.wgInterface.GetDevice(); filteredDevice != nil {
|
||||
filteredDevice.SetPanicHandler(e.triggerClientRestart)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
@@ -1711,6 +1714,13 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
// Self-addressed heartbeat: the signal client's receive watchdog
|
||||
// round-trips this through the server to confirm the receive stream
|
||||
// is delivering. Liveness is already recorded before this handler.
|
||||
if msg.GetBody().GetType() == sProto.Body_HEARTBEAT {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
@@ -1909,7 +1919,6 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
WGPrivKey: e.config.WgPrivateKey.String(),
|
||||
MTU: e.config.MTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
DisableDNS: e.config.DisableDNS,
|
||||
}
|
||||
|
||||
@@ -2157,21 +2166,6 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
var vpnRoutes []netip.Prefix
|
||||
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
|
||||
return true, prefix, nil
|
||||
}
|
||||
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSServer() {
|
||||
if e.dnsServer == nil {
|
||||
return
|
||||
|
||||
@@ -1024,14 +1024,17 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
return d.relayStates
|
||||
}
|
||||
|
||||
// extend the list of stun, turn servers with relay address
|
||||
// extend the list of stun, turn servers with the relay server connections
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
|
||||
// if the server connection is not established then we will use the general address
|
||||
// in case of connection we will use the instance specific address
|
||||
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
// TODO add their status
|
||||
states := d.relayMgr.RelayStates()
|
||||
if len(states) == 0 {
|
||||
// no relay connection tracked yet; surface configured servers as
|
||||
// unavailable with the real reconnect error when known
|
||||
err := relayClient.ErrRelayClientNotConnected
|
||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
||||
err = connErr
|
||||
}
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
@@ -1041,10 +1044,14 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
return relayStates
|
||||
}
|
||||
|
||||
relayState := relay.ProbeResult{
|
||||
URI: instanceAddr,
|
||||
for _, rs := range states {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: rs.URL,
|
||||
Err: rs.Err,
|
||||
Transport: rs.Transport,
|
||||
})
|
||||
}
|
||||
return append(relayStates, relayState)
|
||||
return relayStates
|
||||
}
|
||||
|
||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
@@ -1405,6 +1412,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
Transport: relayState.Transport,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -165,10 +164,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
if candidateViaRoutes(candidate, haRoutes) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
@@ -589,34 +584,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
var routePrefixes []netip.Prefix
|
||||
for _, routes := range clientRoutes {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
routePrefixes = append(routePrefixes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range routePrefixes {
|
||||
// default route is handled by route exclusion / ip rules
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(addr) {
|
||||
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
@@ -108,6 +108,10 @@ type ConfigInput struct {
|
||||
|
||||
// Config Configuration type
|
||||
type Config struct {
|
||||
// Name is the human-readable profile name shown in CLI/UI listings.
|
||||
// It is independent of the profile's on-disk filename (which is the ID).
|
||||
Name string
|
||||
|
||||
// Wireguard private key of local peer
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
@@ -270,6 +274,16 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
if config.Name != "" {
|
||||
sanitized, err := sanitizeDisplayName(config.Name)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid profile name: %w", err)
|
||||
}
|
||||
if sanitized != config.Name {
|
||||
config.Name = sanitized
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
if config.ManagementURL == nil {
|
||||
log.Infof("using default Management URL %s", DefaultManagementURL)
|
||||
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
|
||||
|
||||
118
client/internal/profilemanager/id.go
Normal file
118
client/internal/profilemanager/id.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
// profileIDByteLen is the number of random bytes generated for a new
|
||||
// profile ID. The resulting hex string is twice this length.
|
||||
profileIDByteLen = 16
|
||||
|
||||
// shortIDLen is the number of leading characters of an ID we render in
|
||||
// list output. Profiles per device are few, so 8 chars is collision-safe
|
||||
// in practice and easy to type as a prefix.
|
||||
shortIDLen = 8
|
||||
|
||||
// maxProfileNameLen caps the human-readable profile name to keep table
|
||||
// output legible and prevent denial-of-service via huge JSON fields.
|
||||
maxProfileNameLen = 128
|
||||
|
||||
// maxProfileIDLen bounds the on-disk filename we'll accept. New
|
||||
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
|
||||
// cap is generous enough to cover both without permitting absurdly
|
||||
// long filenames.
|
||||
maxProfileIDLen = 64
|
||||
)
|
||||
|
||||
type ID string
|
||||
|
||||
// generateProfileID returns a new random hex ID for a profile file.
|
||||
func generateProfileID() (ID, error) {
|
||||
buf := make([]byte, profileIDByteLen)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", fmt.Errorf("read random bytes: %w", err)
|
||||
}
|
||||
return ID(hex.EncodeToString(buf)), nil
|
||||
}
|
||||
|
||||
// IsValidProfileFilenameStem reports whether id is safe to use as the stem
|
||||
// of a profile JSON filename.
|
||||
func IsValidProfileFilenameStem(id ID) bool {
|
||||
s := id.String()
|
||||
if s == "" || len(s) > maxProfileIDLen {
|
||||
return false
|
||||
}
|
||||
if s == defaultProfileName {
|
||||
return true
|
||||
}
|
||||
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
|
||||
return false
|
||||
}
|
||||
// filepath.Base catches any leftover separators on platforms with
|
||||
// exotic path conventions.
|
||||
if filepath.Base(s) != s {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// sanitizeDisplayName normalizes a user-supplied profile display name for
|
||||
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
|
||||
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
|
||||
// preserved. Returns an error if nothing usable remains.
|
||||
func sanitizeDisplayName(name string) (string, error) {
|
||||
if !utf8.ValidString(name) {
|
||||
return "", fmt.Errorf("name is not valid UTF-8")
|
||||
}
|
||||
name = StripCtrlChars(name)
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("name is empty after sanitization")
|
||||
}
|
||||
if utf8.RuneCountInString(name) > maxProfileNameLen {
|
||||
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// StripCtrlChars control characters from a name before printing it.
|
||||
func StripCtrlChars(name string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(name))
|
||||
for _, r := range name {
|
||||
// Skip C0 controls and DEL, plus C1 controls (0x80–0x9F).
|
||||
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ShortID truncates an ID for display.
|
||||
func (id ID) ShortID() string {
|
||||
if id == DefaultProfileName {
|
||||
return DefaultProfileName
|
||||
}
|
||||
runes := []rune(id)
|
||||
if len(runes) <= shortIDLen {
|
||||
return id.String()
|
||||
}
|
||||
return string(runes[:shortIDLen])
|
||||
}
|
||||
|
||||
func (id ID) String() string {
|
||||
return string(id)
|
||||
}
|
||||
@@ -19,19 +19,41 @@ const (
|
||||
)
|
||||
|
||||
type Profile struct {
|
||||
Name string
|
||||
// ID is the on-disk filename stem (without .json). For new profiles
|
||||
// it is a 32-char hex string; legacy profiles created before the
|
||||
// ID-keyed layout keep their original name as their ID. The reserved
|
||||
// value "default" identifies the special default profile.
|
||||
ID ID
|
||||
// Name is the human-readable display name. Falls back to ID when the
|
||||
// underlying JSON has no "name" field set.
|
||||
Name string
|
||||
// Path is the absolute path to the profile JSON. Populated by the
|
||||
// loader so callers do not have to reconstruct it from ID + dir.
|
||||
Path string
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
func (p *Profile) FilePath() (string, error) {
|
||||
if p.Name == "" {
|
||||
return "", fmt.Errorf("active profile name is empty")
|
||||
if p.Path != "" {
|
||||
return p.Path, nil
|
||||
}
|
||||
|
||||
if p.Name == defaultProfileName {
|
||||
id := p.ID
|
||||
if id == "" {
|
||||
id = ID(p.Name)
|
||||
}
|
||||
if id == "" {
|
||||
return "", fmt.Errorf("profile ID is empty")
|
||||
}
|
||||
|
||||
if id == defaultProfileName {
|
||||
return DefaultConfigPath, nil
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(id) {
|
||||
return "", fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
username, err := user.Current()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current user: %w", err)
|
||||
@@ -42,10 +64,13 @@ func (p *Profile) FilePath() (string, error) {
|
||||
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, p.Name+".json"), nil
|
||||
return filepath.Join(configDir, id.String()+".json"), nil
|
||||
}
|
||||
|
||||
func (p *Profile) IsDefault() bool {
|
||||
if p.ID != "" {
|
||||
return p.ID == defaultProfileName
|
||||
}
|
||||
return p.Name == defaultProfileName
|
||||
}
|
||||
|
||||
@@ -57,18 +82,24 @@ func NewProfileManager() *ProfileManager {
|
||||
return &ProfileManager{}
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the active profile as recorded in the local
|
||||
// user state file. Only ID is populated.
|
||||
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
prof := pm.getActiveProfileState()
|
||||
return &Profile{Name: prof}, nil
|
||||
id := pm.getActiveProfileState()
|
||||
return &Profile{ID: id}, nil
|
||||
}
|
||||
|
||||
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
// SwitchProfile records the given profile ID as active in the local user
|
||||
// state file.
|
||||
func (pm *ProfileManager) SwitchProfile(id ID) error {
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
if err := pm.setActiveProfileState(profileName); err != nil {
|
||||
if err := pm.setActiveProfileState(id); err != nil {
|
||||
return fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -85,7 +116,7 @@ func sanitizeProfileName(name string) string {
|
||||
}, name)
|
||||
}
|
||||
|
||||
func (pm *ProfileManager) getActiveProfileState() string {
|
||||
func (pm *ProfileManager) getActiveProfileState() ID {
|
||||
|
||||
configDir, err := getConfigDir()
|
||||
if err != nil {
|
||||
@@ -113,10 +144,10 @@ func (pm *ProfileManager) getActiveProfileState() string {
|
||||
return defaultProfileName
|
||||
}
|
||||
|
||||
return profileName
|
||||
return ID(profileName)
|
||||
}
|
||||
|
||||
func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
||||
func (pm *ProfileManager) setActiveProfileState(id ID) error {
|
||||
|
||||
configDir, err := getConfigDir()
|
||||
if err != nil {
|
||||
@@ -125,7 +156,7 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
||||
|
||||
statePath := filepath.Join(configDir, activeProfileStateFilename)
|
||||
|
||||
err = os.WriteFile(statePath, []byte(profileName), 0600)
|
||||
err = os.WriteFile(statePath, []byte(id), 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write active profile state: %w", err)
|
||||
}
|
||||
@@ -142,7 +173,7 @@ func GetLoginHint() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
return ""
|
||||
|
||||
@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
|
||||
|
||||
state, err := sm.GetActiveProfileState()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
|
||||
assert.Equal(t, defaultProfileName, state.ID.String()) // No active profile state yet
|
||||
|
||||
err = sm.SetActiveProfileStateToDefault()
|
||||
assert.NoError(t, err)
|
||||
|
||||
active, err := sm.GetActiveProfileState()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "default", active.Name)
|
||||
assert.Equal(t, "default", active.ID.String())
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
|
||||
currUser, err := user.Current()
|
||||
assert.NoError(t, err)
|
||||
sm := &ServiceManager{}
|
||||
state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
|
||||
state := &ActiveProfileState{ID: "foo", Username: currUser.Username}
|
||||
err = sm.SetActiveProfileState(state)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should error on nil or incomplete state
|
||||
err = sm.SetActiveProfileState(nil)
|
||||
assert.Error(t, err)
|
||||
err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
|
||||
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -23,12 +24,43 @@ var (
|
||||
DefaultConfigPathDir = ""
|
||||
DefaultConfigPath = ""
|
||||
ActiveProfileStatePath = ""
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
|
||||
)
|
||||
|
||||
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
|
||||
// matches more than one profile. Callers can render Candidates to help the
|
||||
// user disambiguate.
|
||||
type ErrAmbiguousHandle struct {
|
||||
Handle string
|
||||
Candidates []Profile
|
||||
Kind AmbiguityKind
|
||||
}
|
||||
|
||||
// AmbiguityKind describes which matcher produced the ambiguity, so callers
|
||||
// can tailor the error message.
|
||||
type AmbiguityKind int
|
||||
|
||||
const (
|
||||
AmbiguityKindIDPrefix AmbiguityKind = iota
|
||||
AmbiguityKindName
|
||||
)
|
||||
|
||||
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
|
||||
// reading all fields
|
||||
type profileMeta struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (e *ErrAmbiguousHandle) Error() string {
|
||||
switch e.Kind {
|
||||
case AmbiguityKindIDPrefix:
|
||||
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
|
||||
default:
|
||||
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
DefaultConfigPathDir = "/var/lib/netbird/"
|
||||
@@ -54,25 +86,34 @@ func init() {
|
||||
}
|
||||
|
||||
type ActiveProfileState struct {
|
||||
Name string `json:"name"`
|
||||
// ID is the on-disk filename stem of the active profile. The JSON tag stays
|
||||
// as "name" for backwards compatibility with active state files written
|
||||
// before the ID-based config files. Legacy values were profile names, which
|
||||
// were also the legacy filename stems, so they still resolve to the correct
|
||||
// file on disk.
|
||||
ID ID `json:"name"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func (a *ActiveProfileState) FilePath() (string, error) {
|
||||
if a.Name == "" {
|
||||
return "", fmt.Errorf("active profile name is empty")
|
||||
if a.ID == "" {
|
||||
return "", fmt.Errorf("active profile ID is empty")
|
||||
}
|
||||
|
||||
if a.Name == defaultProfileName {
|
||||
if a.ID == defaultProfileName {
|
||||
return DefaultConfigPath, nil
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(a.ID) {
|
||||
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
|
||||
}
|
||||
|
||||
configDir, err := getConfigDirForUser(a.Username)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, a.Name+".json"), nil
|
||||
return filepath.Join(configDir, a.ID.String()+".json"), nil
|
||||
}
|
||||
|
||||
type ServiceManager struct {
|
||||
@@ -178,7 +219,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
|
||||
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
|
||||
}
|
||||
return &ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: defaultProfileName,
|
||||
Username: "",
|
||||
}, nil
|
||||
} else {
|
||||
@@ -186,12 +227,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if activeProfile.Name == "" {
|
||||
if activeProfile.ID == "" {
|
||||
if err := s.SetActiveProfileStateToDefault(); err != nil {
|
||||
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
|
||||
}
|
||||
return &ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: defaultProfileName,
|
||||
Username: "",
|
||||
}, nil
|
||||
}
|
||||
@@ -216,25 +257,29 @@ func (s *ServiceManager) setDefaultActiveState() error {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
|
||||
if a == nil || a.Name == "" {
|
||||
if a == nil || a.ID == "" {
|
||||
return errors.New("invalid active profile state")
|
||||
}
|
||||
|
||||
if a.Name != defaultProfileName && a.Username == "" {
|
||||
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
|
||||
if a.ID != defaultProfileName && a.Username == "" {
|
||||
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID)
|
||||
}
|
||||
|
||||
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
|
||||
return fmt.Errorf("invalid profile ID: %q", a.ID)
|
||||
}
|
||||
|
||||
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
|
||||
return fmt.Errorf("failed to write active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile set to %s for %s", a.Name, a.Username)
|
||||
log.Infof("active profile set to %s for %s", a.ID, a.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceManager) SetActiveProfileStateToDefault() error {
|
||||
return s.SetActiveProfileState(&ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: defaultProfileName,
|
||||
Username: "",
|
||||
})
|
||||
}
|
||||
@@ -243,57 +288,117 @@ func (s *ServiceManager) DefaultProfilePath() string {
|
||||
return DefaultConfigPath
|
||||
}
|
||||
|
||||
func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||
// AddProfile creates a new profile with a generated ID. The user-supplied
|
||||
// displayName is stored inside the JSON's name field, the on-disk filename
|
||||
// uses the generated ID.
|
||||
//
|
||||
// The returned Profile carries the freshly-generated ID so callers can
|
||||
// show it to the user (and so the gRPC AddProfileResponse can include
|
||||
// it).
|
||||
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config directory: %w", err)
|
||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
if profileName == defaultProfileName {
|
||||
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
displayName, err = sanitizeDisplayName(displayName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if profileExists {
|
||||
return ErrProfileAlreadyExists
|
||||
return nil, fmt.Errorf("invalid profile name: %w", err)
|
||||
}
|
||||
|
||||
id, err := generateProfileID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate profile id: %w", err)
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, id.String()+".json")
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new config: %w", err)
|
||||
return nil, fmt.Errorf("failed to create new config: %w", err)
|
||||
}
|
||||
cfg.Name = displayName
|
||||
|
||||
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to write profile config: %w", err)
|
||||
}
|
||||
|
||||
err = util.WriteJson(context.Background(), profPath, cfg)
|
||||
return &Profile{
|
||||
ID: id,
|
||||
Name: displayName,
|
||||
Path: profPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ServiceManager) RenameProfile(id ID, username string, newName string) error {
|
||||
displayName, err := sanitizeDisplayName(newName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write profile config: %w", err)
|
||||
return fmt.Errorf("invalid profile name: %w", err)
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
profiles, err := s.loadAllProfiles(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load profiles: %w", err)
|
||||
}
|
||||
|
||||
var target *Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == id {
|
||||
target = &profiles[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if target == nil {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(target.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.Name = displayName
|
||||
|
||||
if err := util.WriteJson(context.Background(), target.Path, cfg); err != nil {
|
||||
return fmt.Errorf("failed to write profile name: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config directory: %w", err)
|
||||
// RemoveProfile deletes the profile identified by id. Callers must have
|
||||
// already resolved any user-supplied handle to a concrete ID via
|
||||
// ResolveProfile.
|
||||
func (s *ServiceManager) RemoveProfile(id ID, username string) error {
|
||||
if id == defaultProfileName {
|
||||
defaultName := readProfileName(DefaultConfigPath)
|
||||
if defaultName == "" {
|
||||
defaultName = defaultProfileName
|
||||
}
|
||||
return fmt.Errorf("cannot remove default profile with name: %s", defaultName)
|
||||
}
|
||||
if !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
if profileName == defaultProfileName {
|
||||
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
profiles, err := s.loadAllProfiles(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
return fmt.Errorf("load profiles: %w", err)
|
||||
}
|
||||
if !profileExists {
|
||||
|
||||
var target *Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == id {
|
||||
target = &profiles[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if target == nil {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
@@ -301,57 +406,26 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
if err != nil && !errors.Is(err, ErrNoActiveProfile) {
|
||||
return fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
|
||||
if activeProf != nil && activeProf.Name == profileName {
|
||||
return fmt.Errorf("cannot remove active profile: %s", profileName)
|
||||
if activeProf != nil && activeProf.ID == id {
|
||||
return fmt.Errorf("cannot remove active profile: %s", id)
|
||||
}
|
||||
|
||||
err = util.RemoveJson(profPath)
|
||||
if err != nil {
|
||||
if err := util.RemoveJson(target.Path); err != nil {
|
||||
return fmt.Errorf("failed to remove profile config: %w", err)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(filepath.Dir(target.Path), id.String()+".state.json")
|
||||
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListProfiles returns every profile for the given user, including the
|
||||
// default profile, with IsActive flags set.
|
||||
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
|
||||
files, err := util.ListFiles(configDir, "*.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list profile files: %w", err)
|
||||
}
|
||||
|
||||
var filtered []string
|
||||
for _, file := range files {
|
||||
if strings.HasSuffix(file, "state.json") {
|
||||
continue // skip state files
|
||||
}
|
||||
filtered = append(filtered, file)
|
||||
}
|
||||
sort.Strings(filtered)
|
||||
|
||||
var activeProfName string
|
||||
activeProf, err := s.GetActiveProfileState()
|
||||
if err == nil {
|
||||
activeProfName = activeProf.Name
|
||||
}
|
||||
|
||||
var profiles []Profile
|
||||
// add default profile always
|
||||
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
|
||||
for _, file := range filtered {
|
||||
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
|
||||
var isActive bool
|
||||
if activeProfName != "" && activeProfName == profileName {
|
||||
isActive = true
|
||||
}
|
||||
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
return s.loadAllProfiles(username)
|
||||
}
|
||||
|
||||
// GetStatePath returns the path to the state file based on the operating system
|
||||
@@ -369,7 +443,12 @@ func (s *ServiceManager) GetStatePath() string {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
if activeProf.Name == defaultProfileName {
|
||||
if activeProf.ID == defaultProfileName {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(activeProf.ID) {
|
||||
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
@@ -379,7 +458,7 @@ func (s *ServiceManager) GetStatePath() string {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, activeProf.Name+".state.json")
|
||||
return filepath.Join(configDir, activeProf.ID.String()+".state.json")
|
||||
}
|
||||
|
||||
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
|
||||
@@ -390,3 +469,169 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
|
||||
|
||||
return getConfigDirForUser(username)
|
||||
}
|
||||
|
||||
// loadAllProfiles returns every profile visible to the daemon for the
|
||||
// given user, including the default profile. The returned slice is sorted
|
||||
// by ID for a stable display order.
|
||||
//
|
||||
// Each Profile is fully populated: ID is the filename stem, Name comes
|
||||
// from the JSON's "name" field (falling back to the filename stem when absent)
|
||||
// and Path is built from a basename read off disk.
|
||||
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
|
||||
activeID, activeIsDefault := s.activeProfileID()
|
||||
defaultName := readProfileName(DefaultConfigPath)
|
||||
if defaultName == "" {
|
||||
defaultName = defaultProfileName
|
||||
}
|
||||
|
||||
profiles := []Profile{{
|
||||
ID: defaultProfileName,
|
||||
Name: defaultName,
|
||||
Path: DefaultConfigPath,
|
||||
IsActive: activeIsDefault,
|
||||
}}
|
||||
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get config directory: %w", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(configDir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return profiles, nil
|
||||
}
|
||||
return nil, fmt.Errorf("read profile directory: %w", err)
|
||||
}
|
||||
|
||||
var fileProfiles []Profile
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
base := entry.Name()
|
||||
if !strings.HasSuffix(base, ".json") {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(base, ".state.json") {
|
||||
continue
|
||||
}
|
||||
stem := ID(strings.TrimSuffix(base, ".json"))
|
||||
if stem == defaultProfileName {
|
||||
// default lives at the top-level config dir, not under /<user>
|
||||
continue
|
||||
}
|
||||
if !IsValidProfileFilenameStem(ID(stem)) {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(configDir, base)
|
||||
name := readProfileName(path)
|
||||
if name == "" {
|
||||
name = stem.String()
|
||||
}
|
||||
fileProfiles = append(fileProfiles, Profile{
|
||||
ID: stem,
|
||||
Name: name,
|
||||
Path: path,
|
||||
IsActive: stem == ID(activeID),
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(fileProfiles, func(i, j int) bool {
|
||||
if fileProfiles[i].Name != fileProfiles[j].Name {
|
||||
return fileProfiles[i].Name < fileProfiles[j].Name
|
||||
}
|
||||
// Sort tie-break on ID so duplicate names always render in the same order.
|
||||
return fileProfiles[i].ID < fileProfiles[j].ID
|
||||
})
|
||||
profiles = append(profiles, fileProfiles...)
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// readProfileName parses just the "name" field from the profile Json.
|
||||
func readProfileName(path string) string {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var meta profileMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
return ""
|
||||
}
|
||||
return meta.Name
|
||||
}
|
||||
|
||||
// activeProfileID returns the currently-active profile's ID. The second
|
||||
// return value is true when the active profile is the default one.
|
||||
func (s *ServiceManager) activeProfileID() (ID, bool) {
|
||||
state, err := s.GetActiveProfileState()
|
||||
if err != nil || state == nil {
|
||||
return defaultProfileName, true
|
||||
}
|
||||
if state.ID == "" || state.ID == defaultProfileName {
|
||||
return defaultProfileName, true
|
||||
}
|
||||
return state.ID, false
|
||||
}
|
||||
|
||||
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
|
||||
// precedence is: exact ID match, then unique exact name, then unique ID
|
||||
// prefix. Ambiguous matches return *ErrAmbiguousHandle so callers can
|
||||
// surface the candidates.
|
||||
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
|
||||
if handle == "" {
|
||||
return nil, fmt.Errorf("profile handle is empty")
|
||||
}
|
||||
|
||||
profiles, err := s.loadAllProfiles(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == ID(handle) {
|
||||
return &profiles[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
var nameMatches []Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].Name == handle {
|
||||
nameMatches = append(nameMatches, profiles[i])
|
||||
}
|
||||
}
|
||||
if len(nameMatches) == 1 {
|
||||
return &nameMatches[0], nil
|
||||
}
|
||||
if len(nameMatches) > 1 {
|
||||
return nil, &ErrAmbiguousHandle{
|
||||
Handle: handle,
|
||||
Candidates: nameMatches,
|
||||
Kind: AmbiguityKindName,
|
||||
}
|
||||
}
|
||||
|
||||
// ID prefix match. Skip the default profile so `select d` does not
|
||||
// accidentally pick it via prefix.
|
||||
var prefixMatches []Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == defaultProfileName {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(profiles[i].ID.String(), handle) {
|
||||
prefixMatches = append(prefixMatches, profiles[i])
|
||||
}
|
||||
}
|
||||
if len(prefixMatches) == 1 {
|
||||
return &prefixMatches[0], nil
|
||||
}
|
||||
if len(prefixMatches) > 1 {
|
||||
return nil, &ErrAmbiguousHandle{
|
||||
Handle: handle,
|
||||
Candidates: prefixMatches,
|
||||
Kind: AmbiguityKindIDPrefix,
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrProfileNotFound
|
||||
}
|
||||
|
||||
230
client/internal/profilemanager/service_test.go
Normal file
230
client/internal/profilemanager/service_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// withTestSM wires up patched globals + a clean config dir and returns a
|
||||
// fully initialized ServiceManager plus the username we are scoped to.
|
||||
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
|
||||
t.Helper()
|
||||
withTempConfigDir(t, func(configDir string) {
|
||||
withPatchedGlobals(t, configDir, func() {
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
sm := &ServiceManager{}
|
||||
require.NoError(t, sm.CreateDefaultProfile())
|
||||
fn(sm, u.Username)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_ExactID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sm.ResolveProfile(created.ID.String(), username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
assert.Equal(t, "work", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_IDPrefix(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
prefix := created.ID[:4]
|
||||
got, err := sm.ResolveProfile(prefix.String(), username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
// Plant two profiles whose IDs share a known prefix by writing
|
||||
// the files directly, since generated IDs are random.
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
|
||||
path := filepath.Join(configDir, id+".json")
|
||||
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
|
||||
}
|
||||
|
||||
_, err = sm.ResolveProfile("abcd", username)
|
||||
var amb *ErrAmbiguousHandle
|
||||
require.ErrorAs(t, err, &amb)
|
||||
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
|
||||
assert.Len(t, amb.Candidates, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_ExactNameUnique(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sm.ResolveProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "work", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_AmbiguousName(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
_, err = sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sm.ResolveProfile("work", username)
|
||||
var amb *ErrAmbiguousHandle
|
||||
require.ErrorAs(t, err, &amb)
|
||||
assert.Equal(t, AmbiguityKindName, amb.Kind)
|
||||
assert.Len(t, amb.Candidates, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_NotFound(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.ResolveProfile("nope", username)
|
||||
assert.ErrorIs(t, err, ErrProfileNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_DefaultByExactID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
got, err := sm.ResolveProfile(defaultProfileName, username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, defaultProfileName, got.ID.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
|
||||
// Legacy profiles stored as <name>.json with no "name" JSON field
|
||||
// should still be discoverable by name and removable by name.
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
path := filepath.Join(configDir, "legacy.json")
|
||||
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
|
||||
|
||||
got, err := sm.ResolveProfile("legacy", username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "legacy", got.ID.String())
|
||||
// Name falls back to the filename stem when JSON omits it.
|
||||
assert.Equal(t, "legacy", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
first, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
second, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, first.ID, second.ID)
|
||||
assert.Equal(t, "work", second.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
cases := []string{
|
||||
"", // empty
|
||||
"\x00\x01", // only control chars (becomes empty)
|
||||
strings.Repeat("a", maxProfileNameLen+1), // too long
|
||||
}
|
||||
for _, name := range cases {
|
||||
_, err := sm.AddProfile(name, username)
|
||||
assert.Error(t, err, "expected error for %q", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
err := sm.RemoveProfile("../escape", username)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeDisplayName(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"work", "work", false},
|
||||
{"My Work Account", "My Work Account", false},
|
||||
{"emoji 🚀 ok", "emoji 🚀 ok", false},
|
||||
{"漢字テスト", "漢字テスト", false},
|
||||
{"with\x00null", "withnull", false},
|
||||
{"\x01\x02\x03", "", true},
|
||||
{"", "", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got, err := sanitizeDisplayName(tc.in)
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err, "case %q", tc.in)
|
||||
continue
|
||||
}
|
||||
assert.NoError(t, err, "case %q", tc.in)
|
||||
assert.Equal(t, tc.want, got, "case %q", tc.in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidProfileFilenameStem(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"default", true},
|
||||
{"abc123def456", true},
|
||||
{"legacy-name", true},
|
||||
{"legacy_name", true},
|
||||
{"", false},
|
||||
{"..", false},
|
||||
{"../etc", false},
|
||||
{"foo/bar", false},
|
||||
{`foo\bar`, false},
|
||||
{"with space", false},
|
||||
{"with.dot", false},
|
||||
{strings.Repeat("a", maxProfileIDLen+1), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := IsValidProfileFilenameStem(ID(tc.in))
|
||||
assert.Equal(t, tc.want, got, "case %q", tc.in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
statePath := filepath.Join(configDir, created.ID.String()+".state.json")
|
||||
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
|
||||
|
||||
require.NoError(t, sm.RemoveProfile(created.ID, username))
|
||||
_, err = os.Stat(statePath)
|
||||
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
|
||||
})
|
||||
}
|
||||
@@ -13,13 +13,20 @@ type ProfileState struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
|
||||
// GetProfileState reads the per-profile state file keyed by profile ID.
|
||||
// The state file lives in the user's config directory. Legacy state files
|
||||
// keyed by the old profile name remain readable.
|
||||
func (pm *ProfileManager) GetProfileState(id ID) (*ProfileState, error) {
|
||||
configDir, err := getConfigDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get config directory: %w", err)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return nil, fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, id.String()+".state.json")
|
||||
stateFileExists, err := fileExists(stateFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
|
||||
@@ -51,7 +58,12 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
|
||||
return fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
|
||||
id := activeProf.ID
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid active profile ID: %q", id)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, id.String()+".state.json")
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write profile state: %w", err)
|
||||
|
||||
@@ -32,6 +32,9 @@ type ProbeResult struct {
|
||||
URI string
|
||||
Err error
|
||||
Addr string
|
||||
// Transport is the negotiated relay transport, empty
|
||||
// for stun/turn probes or when not connected.
|
||||
Transport string
|
||||
}
|
||||
|
||||
type StunTurnProbe struct {
|
||||
|
||||
@@ -22,14 +22,14 @@ type removePeerCall struct {
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
}
|
||||
|
||||
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||
@@ -51,7 +51,7 @@ func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||
return m.removeErr
|
||||
}
|
||||
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||
|
||||
type setPSKCall struct {
|
||||
|
||||
@@ -41,4 +41,3 @@ func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||
_, err = DeterministicSeedKey(long, short)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -332,6 +333,8 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
}
|
||||
}
|
||||
|
||||
m.notifier.Close()
|
||||
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.clientRoutes = nil
|
||||
@@ -700,6 +703,8 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
m.mirrorV6ExitPairSelections(clientRoutes)
|
||||
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
@@ -716,6 +721,24 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
|
||||
m.logExitNodeUpdate(exitNodeInfo)
|
||||
}
|
||||
|
||||
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
|
||||
// consistent with its v4 base. The v4/v6 exit pair is a single toggle, so the v6
|
||||
// entry always follows the base: deselecting the v4 exit node also drops its ::/0
|
||||
// pair, and any stale (orphaned) explicit selection on the v6 entry is reset. This
|
||||
// runs before selection is read so both collectExitNodeInfo and FilterSelectedExitNodes
|
||||
// see consistent state, including pairs loaded from persisted selector state.
|
||||
func (m *DefaultManager) mirrorV6ExitPairSelections(clientRoutes route.HAMap) {
|
||||
routesByNetID := make(map[route.NetID][]*route.Route, len(clientRoutes))
|
||||
for haID, routes := range clientRoutes {
|
||||
routesByNetID[haID.NetID()] = routes
|
||||
}
|
||||
|
||||
for v6ID := range route.V6ExitMergeSet(routesByNetID) {
|
||||
baseID := route.NetID(strings.TrimSuffix(string(v6ID), route.V6ExitSuffix))
|
||||
m.routeSelector.SyncPairedSelection(baseID, v6ID)
|
||||
}
|
||||
}
|
||||
|
||||
type exitNodeInfo struct {
|
||||
allIDs []route.NetID
|
||||
selectedByManagement []route.NetID
|
||||
|
||||
47
client/internal/routemanager/manager_v6exit_test.go
Normal file
47
client/internal/routemanager/manager_v6exit_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair reproduces the bug seen
|
||||
// in netbird-engine.log: persisted selector state has the v4 exit node deselected
|
||||
// but its synthesized "-v6" pair explicitly selected (orphaned), so the ::/0 route
|
||||
// leaked onto the tunnel. The management update must mirror the v4 deselect onto the
|
||||
// v6 pair so FilterSelectedExitNodes drops it.
|
||||
func TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair(t *testing.T) {
|
||||
const (
|
||||
v4ID = route.NetID("Exit Node (raspberrypi)")
|
||||
v6ID = route.NetID("Exit Node (raspberrypi)-v6")
|
||||
)
|
||||
all := []route.NetID{v4ID, v6ID}
|
||||
|
||||
rs := routeselector.NewRouteSelector()
|
||||
// Orphan the v6 selection: select the pair, then deselect only the v4 base.
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{v4ID, v6ID}, true, all))
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{v4ID}, all))
|
||||
require.True(t, rs.IsSelected(v6ID), "precondition: orphaned v6 selection survives v4 deselect")
|
||||
|
||||
m := &DefaultManager{routeSelector: rs}
|
||||
|
||||
v4Route := &route.Route{NetID: v4ID, Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: v6ID, Network: netip.MustParsePrefix("::/0")}
|
||||
clientRoutes := route.HAMap{
|
||||
"Exit Node (raspberrypi)|0.0.0.0/0": {v4Route},
|
||||
"Exit Node (raspberrypi)-v6|::/0": {v6Route},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(clientRoutes)
|
||||
|
||||
assert.False(t, rs.IsSelected(v6ID), "v6 pair must follow the v4 base deselect after the management update")
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(clientRoutes)
|
||||
assert.Empty(t, filtered, "deselected v4 exit node must not leak its ::/0 pair onto the tunnel")
|
||||
}
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
type Notifier struct {
|
||||
initialRoutes []*route.Route
|
||||
currentRoutes []*route.Route
|
||||
fakeIPRoutes []*route.Route
|
||||
fakeIPRoutes []*route.Route
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
@@ -119,3 +119,7 @@ func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
sort.Strings(initialStrings)
|
||||
return initialStrings
|
||||
}
|
||||
|
||||
func (n *Notifier) Close() {
|
||||
// unused
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -14,19 +15,26 @@ import (
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
currentPrefixes []string
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
listener listener.NetworkChangeListener
|
||||
queue *list.List
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
n := &Notifier{
|
||||
queue: list.New(),
|
||||
}
|
||||
n.cond = sync.NewCond(&n.mu)
|
||||
go n.deliverLoop()
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
@@ -43,32 +51,52 @@ func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
newNets := make([]string, 0, len(prefixes))
|
||||
for _, prefix := range prefixes {
|
||||
newNets = append(newNets, prefix.String())
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
|
||||
n.mu.Lock()
|
||||
if slices.Equal(n.currentPrefixes, newNets) {
|
||||
n.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
n.currentPrefixes = newNets
|
||||
n.notify()
|
||||
routes := strings.Join(n.currentPrefixes, ",")
|
||||
n.queue.PushBack(routes)
|
||||
n.cond.Signal()
|
||||
n.mu.Unlock()
|
||||
}
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(n.currentPrefixes, ","))
|
||||
}(n.listener)
|
||||
func (n *Notifier) Close() {
|
||||
n.mu.Lock()
|
||||
n.closed = true
|
||||
n.cond.Signal()
|
||||
n.mu.Unlock()
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Notifier) deliverLoop() {
|
||||
for {
|
||||
n.mu.Lock()
|
||||
for n.queue.Len() == 0 && !n.closed {
|
||||
n.cond.Wait()
|
||||
}
|
||||
if n.closed && n.queue.Len() == 0 {
|
||||
n.mu.Unlock()
|
||||
return
|
||||
}
|
||||
routes := n.queue.Remove(n.queue.Front()).(string)
|
||||
l := n.listener
|
||||
n.mu.Unlock()
|
||||
|
||||
if l != nil {
|
||||
l.OnNetworkChanged(routes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,3 +38,7 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
func (n *Notifier) Close() {
|
||||
// unused
|
||||
}
|
||||
|
||||
@@ -121,9 +121,12 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Check if the prefix is part of any local subnets
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
// BSDs blackhole a /32 added inside a directly-connected subnet; Linux/Windows need it to beat the wt0 route.
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "freebsd", "netbsd", "openbsd", "dragonfly":
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -132,6 +131,33 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
return rs.isSelectedLocked(routeID)
|
||||
}
|
||||
|
||||
// SyncPairedSelection forces pairedID's explicit selection state to match baseID's,
|
||||
// so a synthesized "-v6" exit route always follows its v4 base: selecting or
|
||||
// deselecting the v4 exit node governs the ::/0 pair, and any stale (orphaned)
|
||||
// explicit state on the v6 entry is reset. The v4/v6 exit pair is treated as a single
|
||||
// toggle, so the v6 entry carries no independent selection of its own.
|
||||
func (rs *RouteSelector) SyncPairedSelection(baseID, pairedID route.NetID) {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if rs.deselectAll {
|
||||
return
|
||||
}
|
||||
|
||||
_, baseSelected := rs.selectedRoutes[baseID]
|
||||
_, baseDeselected := rs.deselectedRoutes[baseID]
|
||||
|
||||
delete(rs.selectedRoutes, pairedID)
|
||||
delete(rs.deselectedRoutes, pairedID)
|
||||
|
||||
switch {
|
||||
case baseSelected:
|
||||
rs.selectedRoutes[pairedID] = struct{}{}
|
||||
case baseDeselected:
|
||||
rs.deselectedRoutes[pairedID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
rs.mu.RLock()
|
||||
@@ -151,14 +177,13 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
}
|
||||
|
||||
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
|
||||
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
|
||||
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
|
||||
// transparently apply to the synthesized v6 entry.
|
||||
// The lookup is literal; v4/v6 exit pairs are kept consistent at write time via SyncPairedSelection,
|
||||
// so a synthesized "-v6" entry carries the same explicit state as its v4 base.
|
||||
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
|
||||
return rs.hasUserSelectionForRouteLocked(routeID)
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
|
||||
@@ -187,83 +212,6 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
|
||||
return filtered
|
||||
}
|
||||
|
||||
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
|
||||
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
|
||||
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
|
||||
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
|
||||
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
|
||||
name := string(id)
|
||||
if !strings.HasSuffix(name, route.V6ExitSuffix) {
|
||||
return id
|
||||
}
|
||||
if _, ok := rs.selectedRoutes[id]; ok {
|
||||
return id
|
||||
}
|
||||
if _, ok := rs.deselectedRoutes[id]; ok {
|
||||
return id
|
||||
}
|
||||
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return false
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return !deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return true
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[netID]
|
||||
return deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return selected || deselected
|
||||
}
|
||||
|
||||
func isExitNode(rt []*route.Route) bool {
|
||||
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) applyExitNodeFilter(
|
||||
id route.HAUniqueID,
|
||||
netID route.NetID,
|
||||
rt []*route.Route,
|
||||
out route.HAMap,
|
||||
) {
|
||||
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
|
||||
// drops the synthesized v6 entry that lacks its own explicit state.
|
||||
effective := rs.effectiveNetID(netID)
|
||||
if rs.hasUserSelectionForRouteLocked(effective) {
|
||||
if rs.isSelectedLocked(effective) {
|
||||
out[id] = rt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no explicit selection for this route: defer to management's SkipAutoApply flag
|
||||
sel := collectSelected(rt)
|
||||
if len(sel) > 0 {
|
||||
out[id] = sel
|
||||
}
|
||||
}
|
||||
|
||||
func collectSelected(rt []*route.Route) []*route.Route {
|
||||
var sel []*route.Route
|
||||
for _, r := range rt {
|
||||
if !r.SkipAutoApply {
|
||||
sel = append(sel, r)
|
||||
}
|
||||
}
|
||||
return sel
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface
|
||||
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
|
||||
rs.mu.RLock()
|
||||
@@ -317,3 +265,59 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return false
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return !deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
|
||||
if rs.deselectAll {
|
||||
return true
|
||||
}
|
||||
_, deselected := rs.deselectedRoutes[netID]
|
||||
return deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
return selected || deselected
|
||||
}
|
||||
|
||||
func (rs *RouteSelector) applyExitNodeFilter(
|
||||
id route.HAUniqueID,
|
||||
netID route.NetID,
|
||||
rt []*route.Route,
|
||||
out route.HAMap,
|
||||
) {
|
||||
if rs.hasUserSelectionForRouteLocked(netID) {
|
||||
if rs.isSelectedLocked(netID) {
|
||||
out[id] = rt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no explicit selection for this route: defer to management's SkipAutoApply flag
|
||||
sel := collectSelected(rt)
|
||||
if len(sel) > 0 {
|
||||
out[id] = sel
|
||||
}
|
||||
}
|
||||
|
||||
func isExitNode(rt []*route.Route) bool {
|
||||
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
|
||||
}
|
||||
|
||||
func collectSelected(rt []*route.Route) []*route.Route {
|
||||
var sel []*route.Route
|
||||
for _, r := range rt {
|
||||
if !r.SkipAutoApply {
|
||||
sel = append(sel, r)
|
||||
}
|
||||
}
|
||||
return sel
|
||||
}
|
||||
|
||||
@@ -330,39 +330,73 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
|
||||
assert.Len(t, filtered, 0) // No routes should be selected
|
||||
}
|
||||
|
||||
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
|
||||
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
|
||||
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
|
||||
// v4 base, so legacy persisted selections that predate v6 pairing transparently
|
||||
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
|
||||
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
|
||||
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||
// TestRouteSelector_V6ExitPairSync covers SyncPairedSelection, which keeps a v4
|
||||
// exit node and its synthesized "-v6" counterpart consistent. The selector itself
|
||||
// is literal and never infers a v6 entry's state from its v4 base; callers that know
|
||||
// the pairing (exit-node code paths) call SyncPairedSelection to force the v6 entry
|
||||
// to follow the base, treating the pair as a single toggle.
|
||||
func TestRouteSelector_V6ExitPairSync(t *testing.T) {
|
||||
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
|
||||
|
||||
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
|
||||
t.Run("selector lookups stay literal without sync", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
|
||||
// The selector does not pair-resolve: the v6 entry is independent until synced.
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 entry has no state of its own")
|
||||
assert.True(t, rs.IsSelected("exit1-v6"), "unsynced v6 entry stays selected by default")
|
||||
|
||||
// unrelated v6 with no v4 base touched is unaffected
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
|
||||
// A route literally named "exit1-something" must never pair-resolve either.
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
|
||||
})
|
||||
|
||||
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||
|
||||
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
|
||||
// via the mirror; the mirror only applies in exit-node code paths.
|
||||
assert.False(t, rs.IsSelected("corp"))
|
||||
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
|
||||
})
|
||||
|
||||
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
|
||||
t.Run("sync mirrors deselected v4 base onto v6", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.False(t, rs.IsSelected("exit1"))
|
||||
assert.False(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base deselect")
|
||||
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 carries explicit deselect after sync")
|
||||
})
|
||||
|
||||
t.Run("sync mirrors selected v4 base onto v6", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1"}, false, all))
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.True(t, rs.IsSelected("exit1"))
|
||||
assert.True(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base select")
|
||||
})
|
||||
|
||||
t.Run("sync clears v6 state when base has no explicit selection", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
|
||||
require.True(t, rs.HasUserSelectionForRoute("exit1-v6"))
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"),
|
||||
"v6 explicit state is cleared so it follows management like its base")
|
||||
})
|
||||
|
||||
// Regression for the observed bug (see netbird-engine.log): persisted state has
|
||||
// the v4 base deselected but the v6 sibling explicitly selected (orphaned). The
|
||||
// sync must reset the orphan so the ::/0 route does not leak onto the tunnel.
|
||||
t.Run("sync clears orphaned explicit v6 selection on deselected base", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
// Prior state: both explicitly selected, then only the v4 base deselected,
|
||||
// leaving the v6 entry as a stale explicit selection.
|
||||
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1", "exit1-v6"}, true, all))
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
require.True(t, rs.IsSelected("exit1-v6"), "precondition: orphaned v6 selection")
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.False(t, rs.IsSelected("exit1-v6"), "orphaned v6 selection reset to follow v4 deselect")
|
||||
|
||||
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||
@@ -370,23 +404,14 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||
"exit1|0.0.0.0/0": {v4Route},
|
||||
"exit1-v6|::/0": {v6Route},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelectedExitNodes(routes)
|
||||
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
|
||||
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
|
||||
assert.Empty(t, filtered, "deselecting v4 base must drop the v6 pair even if it was explicitly selected before")
|
||||
})
|
||||
|
||||
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
|
||||
// A route literally named "exit1-something" must not pair-resolve.
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
|
||||
})
|
||||
|
||||
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
|
||||
t.Run("filter drops synced v6 pair of deselected v4 base", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
|
||||
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
|
||||
@@ -399,6 +424,15 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
|
||||
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
|
||||
})
|
||||
|
||||
t.Run("deselectAll makes sync a no-op", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
rs.DeselectAllRoutes()
|
||||
|
||||
rs.SyncPairedSelection("exit1", "exit1-v6")
|
||||
|
||||
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "sync must not write explicit state under deselectAll")
|
||||
})
|
||||
|
||||
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
types "github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -54,6 +56,7 @@ type selectRoute struct {
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
Status string
|
||||
extraNetworks []netip.Prefix
|
||||
}
|
||||
|
||||
@@ -65,6 +68,8 @@ func init() {
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
stateFile string
|
||||
cacheDir string
|
||||
logFilePath string
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
@@ -75,16 +80,21 @@ type Client struct {
|
||||
onHostDnsFn func([]string)
|
||||
dnsManager dns.IosDnsManager
|
||||
loginComplete bool
|
||||
connectClient *internal.ConnectClient
|
||||
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
|
||||
preloadedConfig *profilemanager.Config
|
||||
|
||||
stateMu sync.RWMutex
|
||||
connectClient *internal.ConnectClient
|
||||
config *profilemanager.Config
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
||||
func NewClient(cfgFile, stateFile, cacheDir, logFilePath, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
stateFile: stateFile,
|
||||
cacheDir: cacheDir,
|
||||
logFilePath: logFilePath,
|
||||
deviceName: deviceName,
|
||||
osName: osName,
|
||||
osVersion: osVersion,
|
||||
@@ -161,8 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, connectClient)
|
||||
// Persist the latest sync response so DebugBundle can include the network
|
||||
// map. On iOS this is backed by disk to keep it out of the constrained
|
||||
// process memory (see the syncstore package).
|
||||
connectClient.SetSyncResponsePersistence(true)
|
||||
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -174,6 +189,84 @@ func (c *Client) Stop() {
|
||||
}
|
||||
|
||||
c.ctxCancel()
|
||||
c.setState(nil, nil)
|
||||
}
|
||||
|
||||
// DebugBundle generates a debug bundle, uploads it and returns the upload key.
|
||||
// It works with or without a running engine: when the engine is up it reuses
|
||||
// the live config, sync response and client metrics; otherwise it loads the
|
||||
// config from disk (or the preloaded tvOS config).
|
||||
func (c *Client) DebugBundle(anonymize bool) (string, error) {
|
||||
cfg, cc := c.stateSnapshot()
|
||||
|
||||
// If the engine hasn't been started, load config so we can reach management.
|
||||
if cfg == nil {
|
||||
if c.preloadedConfig != nil {
|
||||
cfg = c.preloadedConfig
|
||||
} else {
|
||||
var err error
|
||||
// Use DirectUpdateOrCreateConfig to avoid atomic file operations
|
||||
// (temp file + rename) blocked by the tvOS sandbox.
|
||||
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
StateFilePath: c.stateFile,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deps := debug.GeneratorDependencies{
|
||||
InternalConfig: cfg,
|
||||
StatusRecorder: c.recorder,
|
||||
TempDir: c.cacheDir,
|
||||
StatePath: c.stateFile,
|
||||
LogPath: c.logFilePath,
|
||||
}
|
||||
|
||||
if cc != nil {
|
||||
resp, err := cc.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
log.Warnf("get latest sync response: %v", err)
|
||||
}
|
||||
deps.SyncResponse = resp
|
||||
|
||||
if e := cc.Engine(); e != nil {
|
||||
if cm := e.GetClientMetrics(); cm != nil {
|
||||
deps.ClientMetrics = cm
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
deps,
|
||||
debug.BundleConfig{
|
||||
Anonymize: anonymize,
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("debug bundle uploaded with key %s", key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SetTraceLogLevel configure the logger to trace level
|
||||
@@ -227,6 +320,16 @@ func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
|
||||
// IsLoginRequiredCached reports whether the LAST observed management error was an
|
||||
// auth failure (PermissionDenied/InvalidArgument), using the in-memory status
|
||||
// recorder. Unlike IsLoginRequired() it performs NO network call, so it is safe to
|
||||
// call from the connection listener during teardown (e.g. onDisconnected) without
|
||||
// blocking on a slow or unavailable network. Returns false while connected to
|
||||
// management or when the last error was not auth-related.
|
||||
func (c *Client) IsLoginRequiredCached() bool {
|
||||
return c.recorder.IsLoginRequired()
|
||||
}
|
||||
|
||||
func (c *Client) IsLoginRequired() bool {
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
@@ -354,11 +457,12 @@ func (c *Client) ClearLoginComplete() {
|
||||
}
|
||||
|
||||
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
if c.connectClient == nil {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := c.connectClient.Engine()
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -377,9 +481,57 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
// Compute each route's connection status in the core (mirroring the Android
|
||||
// bridge), so the UI doesn't have to infer it by string-matching the joined
|
||||
// Network value against peer routes. For a merged exit node the status reflects
|
||||
// whichever of the v4/v6 prefixes is served by a connected peer; for dynamic
|
||||
// (DNS) routes the peer route key is the domain pattern (see dynamic.Route.String).
|
||||
connectedRoutes := c.connectedRouteSet()
|
||||
for _, r := range routes {
|
||||
r.Status = routeStatus(r, connectedRoutes)
|
||||
}
|
||||
|
||||
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
|
||||
}
|
||||
|
||||
// connectedRouteSet returns the set of route keys (as strings) currently served by a
|
||||
// connected peer, gathered across all connected peers' route tables. The keys match
|
||||
// what the route manager records: a prefix string for static routes (e.g. "0.0.0.0/0")
|
||||
// and the domain pattern for dynamic routes (e.g. "*.example.com").
|
||||
func (c *Client) connectedRouteSet() map[string]struct{} {
|
||||
connected := map[string]struct{}{}
|
||||
for _, p := range c.recorder.GetFullStatus().Peers {
|
||||
if p.ConnStatus != peer.StatusConnected {
|
||||
continue
|
||||
}
|
||||
for r := range p.GetRoutes() {
|
||||
connected[r] = struct{}{}
|
||||
}
|
||||
}
|
||||
return connected
|
||||
}
|
||||
|
||||
// routeStatus reports "Connected" if any of the route's keys is served by a connected
|
||||
// peer: the primary Network prefix, an extra v6 network of a merged exit node, or the
|
||||
// domain pattern for a dynamic DNS route. Otherwise "Idle".
|
||||
func routeStatus(r *selectRoute, connectedRoutes map[string]struct{}) string {
|
||||
keys := make([]string, 0, 1+len(r.extraNetworks))
|
||||
if len(r.Domains) > 0 {
|
||||
keys = append(keys, r.Domains.SafeString())
|
||||
} else {
|
||||
keys = append(keys, r.Network.String())
|
||||
}
|
||||
for _, extra := range r.extraNetworks {
|
||||
keys = append(keys, extra.String())
|
||||
}
|
||||
for _, k := range keys {
|
||||
if _, ok := connectedRoutes[k]; ok {
|
||||
return peer.StatusConnected.String()
|
||||
}
|
||||
}
|
||||
return peer.StatusIdle.String()
|
||||
}
|
||||
|
||||
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
|
||||
var routes []*selectRoute
|
||||
for id, rt := range routesMap {
|
||||
@@ -462,6 +614,7 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
||||
Network: netStr,
|
||||
Domains: &domainDetails,
|
||||
Selected: r.Selected,
|
||||
Status: r.Status,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -470,11 +623,12 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
||||
}
|
||||
|
||||
func (c *Client) SelectRoute(id string) error {
|
||||
if c.connectClient == nil {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := c.connectClient.Engine()
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -500,10 +654,11 @@ func (c *Client) SelectRoute(id string) error {
|
||||
}
|
||||
|
||||
func (c *Client) DeselectRoute(id string) error {
|
||||
if c.connectClient == nil {
|
||||
_, connectClient := c.stateSnapshot()
|
||||
if connectClient == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
engine := c.connectClient.Engine()
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -527,6 +682,22 @@ func (c *Client) DeselectRoute(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setState stores the running engine state so DebugBundle can reuse the live
|
||||
// config and ConnectClient. It is cleared on Stop.
|
||||
func (c *Client) setState(cfg *profilemanager.Config, cc *internal.ConnectClient) {
|
||||
c.stateMu.Lock()
|
||||
defer c.stateMu.Unlock()
|
||||
c.config = cfg
|
||||
c.connectClient = cc
|
||||
}
|
||||
|
||||
// stateSnapshot returns the current config and ConnectClient under the lock.
|
||||
func (c *Client) stateSnapshot() (*profilemanager.Config, *internal.ConnectClient) {
|
||||
c.stateMu.RLock()
|
||||
defer c.stateMu.RUnlock()
|
||||
return c.config, c.connectClient
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
ds := d.String()
|
||||
dotIndex := strings.Index(ds, ".")
|
||||
|
||||
@@ -20,6 +20,7 @@ type RoutesSelectionInfo struct {
|
||||
Network string
|
||||
Domains *DomainDetails
|
||||
Selected bool
|
||||
Status string
|
||||
}
|
||||
|
||||
type DomainCollection interface {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -85,6 +85,8 @@ service DaemonService {
|
||||
|
||||
rpc AddProfile(AddProfileRequest) returns (AddProfileResponse) {}
|
||||
|
||||
rpc RenameProfile(RenameProfileRequest) returns (RenameProfileResponse) {}
|
||||
|
||||
rpc RemoveProfile(RemoveProfileRequest) returns (RemoveProfileResponse) {}
|
||||
|
||||
rpc ListProfiles(ListProfilesRequest) returns (ListProfilesResponse) {}
|
||||
@@ -378,6 +380,9 @@ message RelayState {
|
||||
string URI = 1;
|
||||
bool available = 2;
|
||||
string error = 3;
|
||||
// transport is the negotiated relay transport (e.g. "ws", "quic"),
|
||||
// empty for stun/turn probes or when not connected.
|
||||
string transport = 4;
|
||||
}
|
||||
|
||||
message NSGroupState {
|
||||
@@ -622,11 +627,18 @@ message GetEventsResponse {
|
||||
}
|
||||
|
||||
message SwitchProfileRequest {
|
||||
// profileName is treated as a handle: exact ID, unique ID prefix, or
|
||||
// unique display name. The daemon resolves it server-side.
|
||||
optional string profileName = 1;
|
||||
optional string username = 2;
|
||||
}
|
||||
|
||||
message SwitchProfileResponse {}
|
||||
message SwitchProfileResponse {
|
||||
// id is the resolved on-disk ID of the profile that became active.
|
||||
// Lets CLI clients update their local active-profile state without
|
||||
// duplicating the resolution logic.
|
||||
string id = 1;
|
||||
}
|
||||
|
||||
message SetConfigRequest {
|
||||
string username = 1;
|
||||
@@ -693,17 +705,42 @@ message SetConfigResponse{}
|
||||
|
||||
message AddProfileRequest {
|
||||
string username = 1;
|
||||
// profileName carries the human-readable display name for the new
|
||||
// profile. The on-disk filename is a separately-generated ID.
|
||||
string profileName = 2;
|
||||
}
|
||||
|
||||
message AddProfileResponse {}
|
||||
message AddProfileResponse {
|
||||
// id is the generated on-disk ID of the new profile. CLI clients
|
||||
// display a truncated form, UI clients can ignore it.
|
||||
string id = 1;
|
||||
}
|
||||
|
||||
message RenameProfileRequest {
|
||||
string username = 1;
|
||||
// handle: an exact ID, a unique ID prefix, or a unique display name.
|
||||
string handle = 2;
|
||||
// newProfileName is the new human-readable display name for the profile.
|
||||
string newProfileName = 3;
|
||||
}
|
||||
|
||||
message RenameProfileResponse {
|
||||
// confirm the old profile name after resolving handle.
|
||||
string oldProfileName = 1;
|
||||
}
|
||||
|
||||
message RemoveProfileRequest {
|
||||
string username = 1;
|
||||
// profileName is treated as a handle: an exact ID, a unique ID
|
||||
// prefix, or a unique display name. Resolution happens server-side.
|
||||
string profileName = 2;
|
||||
}
|
||||
|
||||
message RemoveProfileResponse {}
|
||||
message RemoveProfileResponse {
|
||||
// id is the full resolved ID of the removed profile, so callers can
|
||||
// confirm exactly which profile a name/prefix handle resolved to.
|
||||
string id = 1;
|
||||
}
|
||||
|
||||
message ListProfilesRequest {
|
||||
string username = 1;
|
||||
@@ -716,6 +753,7 @@ message ListProfilesResponse {
|
||||
message Profile {
|
||||
string name = 1;
|
||||
bool is_active = 2;
|
||||
string id = 3;
|
||||
}
|
||||
|
||||
message GetActiveProfileRequest {}
|
||||
@@ -723,6 +761,7 @@ message GetActiveProfileRequest {}
|
||||
message GetActiveProfileResponse {
|
||||
string profileName = 1;
|
||||
string username = 2;
|
||||
string id = 3;
|
||||
}
|
||||
|
||||
message LogoutRequest {
|
||||
|
||||
@@ -45,6 +45,7 @@ const (
|
||||
DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile"
|
||||
DaemonService_SetConfig_FullMethodName = "/daemon.DaemonService/SetConfig"
|
||||
DaemonService_AddProfile_FullMethodName = "/daemon.DaemonService/AddProfile"
|
||||
DaemonService_RenameProfile_FullMethodName = "/daemon.DaemonService/RenameProfile"
|
||||
DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile"
|
||||
DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles"
|
||||
DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile"
|
||||
@@ -112,6 +113,7 @@ type DaemonServiceClient interface {
|
||||
SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error)
|
||||
SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error)
|
||||
AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error)
|
||||
RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error)
|
||||
RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error)
|
||||
ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error)
|
||||
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
|
||||
@@ -422,6 +424,16 @@ func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequ
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RenameProfileResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_RenameProfile_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RemoveProfileResponse)
|
||||
@@ -613,6 +625,7 @@ type DaemonServiceServer interface {
|
||||
SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error)
|
||||
SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error)
|
||||
AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error)
|
||||
RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error)
|
||||
RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error)
|
||||
ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error)
|
||||
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
|
||||
@@ -723,6 +736,9 @@ func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigReq
|
||||
func (UnimplementedDaemonServiceServer) AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method AddProfile not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method RenameProfile not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method RemoveProfile not implemented")
|
||||
}
|
||||
@@ -1237,6 +1253,24 @@ func _DaemonService_AddProfile_Handler(srv interface{}, ctx context.Context, dec
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_RenameProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RenameProfileRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).RenameProfile(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_RenameProfile_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).RenameProfile(ctx, req.(*RenameProfileRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_RemoveProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RemoveProfileRequest)
|
||||
if err := dec(in); err != nil {
|
||||
@@ -1567,6 +1601,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "AddProfile",
|
||||
Handler: _DaemonService_AddProfile_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RenameProfile",
|
||||
Handler: _DaemonService_RenameProfile_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RemoveProfile",
|
||||
Handler: _DaemonService_RemoveProfile_Handler,
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestPersistLoginOverrides(t *testing.T) {
|
||||
_, err := profilemanager.UpdateOrCreateConfig(seed)
|
||||
require.NoError(t, err, "seed config")
|
||||
|
||||
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
|
||||
activeProf := &profilemanager.ActiveProfileState{ID: "default"}
|
||||
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
|
||||
require.NoError(t, err, "persistLoginOverrides")
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ type Server struct {
|
||||
// changed by connectWithRetryRuns goroutine exit — for that
|
||||
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
|
||||
// derives from clientGiveUpChan close state. Protected by s.mutex.
|
||||
clientRunning bool
|
||||
clientRunning bool
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
|
||||
@@ -375,7 +375,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config, err := setConfigInputFromRequest(msg)
|
||||
config, err := s.setConfigInputFromRequest(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -398,17 +398,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
// field is its own optional case. Returns the resolved ConfigInput
|
||||
// and a non-nil error only when the active profile file path cannot
|
||||
// be determined.
|
||||
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
|
||||
func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
|
||||
var config profilemanager.ConfigInput
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: msg.ProfileName,
|
||||
Username: msg.Username,
|
||||
}
|
||||
profPath, err := profState.FilePath()
|
||||
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return config, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
log.Errorf("failed to resolve profile %q: %v", msg.ProfileName, err)
|
||||
return config, err
|
||||
}
|
||||
profPath := resolved.Path
|
||||
if profPath == "" {
|
||||
profPath = profilemanager.DefaultConfigPath
|
||||
}
|
||||
config.ConfigPath = profPath
|
||||
|
||||
@@ -535,30 +535,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
|
||||
if msg.ProfileName != nil {
|
||||
if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
|
||||
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
|
||||
}
|
||||
|
||||
var username string
|
||||
if *msg.ProfileName != "default" {
|
||||
username = *msg.Username
|
||||
}
|
||||
|
||||
if *msg.ProfileName != activeProf.Name && username != activeProf.Username {
|
||||
if s.checkProfilesDisabled() {
|
||||
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username)
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: *msg.ProfileName,
|
||||
Username: username,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to set active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to set active profile state: %w", err)
|
||||
}
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,7 +547,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
|
||||
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
|
||||
|
||||
s.mutex.Lock()
|
||||
|
||||
@@ -806,10 +785,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
}
|
||||
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -820,7 +799,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
|
||||
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
|
||||
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
@@ -864,34 +843,60 @@ func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
|
||||
if profileName != "default" && (userName == nil || *userName == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", profileName)
|
||||
return fmt.Errorf("profile name is set to %s, but username is not provided", profileName)
|
||||
// resolveProfileHandle resolves a wire-level profile handle (display
|
||||
// name, ID, or unique ID prefix) to a concrete profile. Returns gRPC
|
||||
// status errors so handlers can return them directly.
|
||||
func (s *Server) resolveProfileHandle(handle, username string) (*profilemanager.Profile, error) {
|
||||
p, err := s.profileManager.ResolveProfile(handle, username)
|
||||
if err == nil {
|
||||
return p, nil
|
||||
}
|
||||
var amb *profilemanager.ErrAmbiguousHandle
|
||||
if errors.As(err, &amb) {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "%v", amb)
|
||||
}
|
||||
if errors.Is(err, profilemanager.ErrProfileNotFound) {
|
||||
return nil, gstatus.Errorf(codes.NotFound, "profile %q not found", handle)
|
||||
}
|
||||
return nil, fmt.Errorf("resolve profile: %w", err)
|
||||
}
|
||||
|
||||
// switchProfileIfNeeded resolves the user-supplied handle, updates the
|
||||
// active profile state if it differs from the current one, and returns
|
||||
// the resolved profile so callers can include its ID in RPC responses.
|
||||
func (s *Server) switchProfileIfNeeded(handle string, userName *string, activeProf *profilemanager.ActiveProfileState) (*profilemanager.Profile, error) {
|
||||
if handle != profilemanager.DefaultProfileName && (userName == nil || *userName == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", handle)
|
||||
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", handle)
|
||||
}
|
||||
|
||||
var username string
|
||||
if profileName != "default" {
|
||||
if handle != profilemanager.DefaultProfileName {
|
||||
username = *userName
|
||||
}
|
||||
|
||||
if profileName != activeProf.Name || username != activeProf.Username {
|
||||
resolved, err := s.resolveProfileHandle(handle, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resolved.ID != activeProf.ID || username != activeProf.Username {
|
||||
if s.checkProfilesDisabled() {
|
||||
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
|
||||
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
log.Infof("switching to profile %s for user %s", profileName, username)
|
||||
log.Infof("switching to profile %s (%s) for user %s", resolved.Name, resolved.ID, username)
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
ID: resolved.ID,
|
||||
Username: username,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to set active profile state: %v", err)
|
||||
return fmt.Errorf("failed to set active profile state: %w", err)
|
||||
return nil, fmt.Errorf("failed to set active profile state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// SwitchProfile switches the active profile in the daemon.
|
||||
@@ -906,9 +911,9 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
}
|
||||
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
activeProf, err = s.profileManager.GetActiveProfileState()
|
||||
@@ -924,7 +929,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
|
||||
s.config = config
|
||||
|
||||
return &proto.SwitchProfileResponse{}, nil
|
||||
return &proto.SwitchProfileResponse{Id: activeProf.ID.String()}, nil
|
||||
}
|
||||
|
||||
// Down engine work in the daemon.
|
||||
@@ -1014,22 +1019,27 @@ func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.L
|
||||
}
|
||||
|
||||
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
|
||||
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msg.Username == nil || *msg.Username == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
|
||||
}
|
||||
username := *msg.Username
|
||||
|
||||
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
|
||||
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
|
||||
resolved, err := s.resolveProfileHandle(*msg.ProfileName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validateProfileOperation(resolved.ID, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.logoutFromProfile(ctx, resolved); err != nil {
|
||||
log.Errorf("failed to logout from profile %s: %v", resolved.ID, err)
|
||||
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
|
||||
}
|
||||
|
||||
activeProf, _ := s.profileManager.GetActiveProfileState()
|
||||
if activeProf != nil && activeProf.Name == *msg.ProfileName {
|
||||
if activeProf != nil && activeProf.ID == resolved.ID {
|
||||
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
||||
log.Errorf("failed to cleanup connection: %v", err)
|
||||
}
|
||||
@@ -1091,30 +1101,30 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
|
||||
return config, configExisted, nil
|
||||
}
|
||||
|
||||
func (s *Server) canRemoveProfile(profileName string) error {
|
||||
if profileName == profilemanager.DefaultProfileName {
|
||||
func (s *Server) canRemoveProfile(id profilemanager.ID) error {
|
||||
if id == profilemanager.DefaultProfileName {
|
||||
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.Name == profileName {
|
||||
return fmt.Errorf("remove active profile: %s", profileName)
|
||||
if err == nil && activeProf.ID == id {
|
||||
return fmt.Errorf("remove active profile: %s", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
|
||||
func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfile bool) error {
|
||||
if s.checkProfilesDisabled() {
|
||||
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if profileName == "" {
|
||||
if id == "" {
|
||||
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
|
||||
}
|
||||
|
||||
if !allowActiveProfile {
|
||||
if err := s.canRemoveProfile(profileName); err != nil {
|
||||
if err := s.canRemoveProfile(id); err != nil {
|
||||
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
}
|
||||
@@ -1122,25 +1132,20 @@ func (s *Server) validateProfileOperation(profileName string, allowActiveProfile
|
||||
return nil
|
||||
}
|
||||
|
||||
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
|
||||
return s.sendLogoutRequest(ctx)
|
||||
}
|
||||
|
||||
profileState := &profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
Username: username,
|
||||
}
|
||||
profilePath, err := profileState.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get profile path: %w", err)
|
||||
cfgPath := profile.Path
|
||||
if cfgPath == "" {
|
||||
cfgPath = profilemanager.DefaultConfigPath
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(profilePath)
|
||||
config, err := profilemanager.GetConfig(cfgPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("profile '%s' not found", profileName)
|
||||
return fmt.Errorf("profile '%s' not found", profile.ID)
|
||||
}
|
||||
|
||||
return s.sendLogoutRequestWithConfig(ctx, config)
|
||||
@@ -1558,15 +1563,14 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
prof := profilemanager.ActiveProfileState{
|
||||
Name: req.ProfileName,
|
||||
Username: req.Username,
|
||||
}
|
||||
|
||||
cfgPath, err := prof.FilePath()
|
||||
resolved, err := s.resolveProfileHandle(req.ProfileName, req.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
log.Errorf("failed to resolve profile %q: %v", req.ProfileName, err)
|
||||
return nil, err
|
||||
}
|
||||
cfgPath := resolved.Path
|
||||
if cfgPath == "" {
|
||||
cfgPath = profilemanager.DefaultConfigPath
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.GetConfig(cfgPath)
|
||||
@@ -1671,12 +1675,39 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
|
||||
}
|
||||
|
||||
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
created, err := s.profileManager.AddProfile(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to create profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.AddProfileResponse{}, nil
|
||||
return &proto.AddProfileResponse{Id: created.ID.String()}, nil
|
||||
}
|
||||
|
||||
func (s *Server) RenameProfile(ctx context.Context, msg *proto.RenameProfileRequest) (*proto.RenameProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkProfilesDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if msg.Handle == "" || msg.Username == "" || msg.NewProfileName == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name, username and new profile name must be provided")
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProfileHandle(msg.Handle, msg.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.profileManager.RenameProfile(resolved.ID, msg.Username, msg.NewProfileName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to rename profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to rename profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RenameProfileResponse{OldProfileName: resolved.Name}, nil
|
||||
}
|
||||
|
||||
// RemoveProfile removes a profile from the daemon.
|
||||
@@ -1684,20 +1715,29 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
|
||||
if s.checkProfilesDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if msg.ProfileName == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
|
||||
if err := s.logoutFromProfile(ctx, resolved); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", resolved.ID, err)
|
||||
}
|
||||
|
||||
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
if err := s.profileManager.RemoveProfile(resolved.ID, msg.Username); err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RemoveProfileResponse{}, nil
|
||||
return &proto.RemoveProfileResponse{Id: resolved.ID.String()}, nil
|
||||
}
|
||||
|
||||
// ListProfiles lists all profiles in the daemon.
|
||||
@@ -1720,6 +1760,7 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
|
||||
}
|
||||
for i, profile := range profiles {
|
||||
response.Profiles[i] = &proto.Profile{
|
||||
Id: profile.ID.String(),
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
}
|
||||
@@ -1728,7 +1769,9 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the active profile in the daemon.
|
||||
// GetActiveProfile returns the active profile in the daemon. The ProfileName
|
||||
// field carries the display name for backwards compatibility with UI clients,
|
||||
// new callers should prefer Id.
|
||||
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
@@ -1739,9 +1782,23 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
// Fallback to legacy name == ID
|
||||
displayName := activeProfile.ID.String()
|
||||
if activeProfile.ID != profilemanager.DefaultProfileName {
|
||||
if profiles, lerr := s.profileManager.ListProfiles(activeProfile.Username); lerr == nil {
|
||||
for _, p := range profiles {
|
||||
if p.ID == activeProfile.ID {
|
||||
displayName = p.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.GetActiveProfileResponse{
|
||||
ProfileName: activeProfile.Name,
|
||||
ProfileName: displayName,
|
||||
Username: activeProfile.Username,
|
||||
Id: activeProfile.ID.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "test-profile",
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -158,7 +158,7 @@ func TestServer_Up(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
ID: profilemanager.ID(profName),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -228,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: "default",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -62,7 +62,7 @@ func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profN
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
ID: profilemanager.ID(profName),
|
||||
Username: currUser.Username,
|
||||
}))
|
||||
|
||||
@@ -107,9 +107,9 @@ func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
|
||||
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
ID: profilemanager.ID(profName),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -96,7 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
DisableNotifications: &disableNotifications,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableIpv6: &disableIPv6,
|
||||
DisableIpv6: &disableIPv6,
|
||||
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||
CleanNATExternalIPs: false,
|
||||
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||
@@ -112,7 +112,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
ID: profilemanager.ID(profName),
|
||||
Username: currUser.Username,
|
||||
}
|
||||
cfgPath, err := profState.FilePath()
|
||||
|
||||
@@ -98,6 +98,7 @@ type RelayStateOutputDetail struct {
|
||||
URI string `json:"uri" yaml:"uri"`
|
||||
Available bool `json:"available" yaml:"available"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
Transport string `json:"transport,omitempty" yaml:"transport,omitempty"`
|
||||
}
|
||||
|
||||
type RelayStateOutput struct {
|
||||
@@ -219,7 +220,8 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
|
||||
RelayStateOutputDetail{
|
||||
URI: relay.URI,
|
||||
Available: available,
|
||||
Error: relay.GetError(),
|
||||
Error: relayErrorString(relay.GetError()),
|
||||
Transport: relay.GetTransport(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -235,6 +237,12 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
|
||||
}
|
||||
}
|
||||
|
||||
// relayErrorString flattens a newline-joined aggregated relay error onto a
|
||||
// single line for status output.
|
||||
func relayErrorString(s string) string {
|
||||
return strings.ReplaceAll(s, "\n", "; ")
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
@@ -441,6 +449,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
} else if relay.Transport != "" {
|
||||
available = fmt.Sprintf("%s via %s", available, relay.Transport)
|
||||
}
|
||||
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
|
||||
@@ -647,3 +647,13 @@ func TestTimeAgo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapRelaysTransport(t *testing.T) {
|
||||
out := mapRelays([]*proto.RelayState{
|
||||
{URI: "rels://relay.example:443", Available: true, Transport: "quic"},
|
||||
{URI: "rels://relay2.example:443", Available: true, Transport: "ws"},
|
||||
})
|
||||
require.Len(t, out.Details, 2)
|
||||
assert.Equal(t, "quic", out.Details[0].Transport)
|
||||
assert.Equal(t, "ws", out.Details[1].Transport)
|
||||
}
|
||||
|
||||
@@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
||||
}
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
}
|
||||
|
||||
@@ -818,13 +818,15 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
|
||||
return nil, fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
handle := activeProf.ID.String()
|
||||
|
||||
loginReq := &proto.LoginRequest{
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
ProfileName: &activeProf.Name,
|
||||
ProfileName: &handle,
|
||||
Username: &currUser.Username,
|
||||
}
|
||||
|
||||
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
|
||||
profileState, err := s.profileManager.GetProfileState(activeProf.ID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -1367,7 +1369,7 @@ func (s *serviceClient) getSrvConfig() {
|
||||
}
|
||||
|
||||
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1613,7 +1615,7 @@ func (s *serviceClient) loadSettings() {
|
||||
}
|
||||
|
||||
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1813,7 +1815,7 @@ func (s *serviceClient) updateConfig() error {
|
||||
}
|
||||
|
||||
req := proto.SetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
|
||||
@@ -66,7 +66,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
} else {
|
||||
indicator.SetText("")
|
||||
}
|
||||
nameLabel.SetText(profile.Name)
|
||||
nameLabel.SetText(formatProfileLabel(profile, profiles))
|
||||
|
||||
// Configure Select/Active button
|
||||
selectBtn.SetText(func() string {
|
||||
@@ -88,7 +88,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
return
|
||||
}
|
||||
// switch
|
||||
err = s.switchProfile(profile.Name)
|
||||
err = s.switchProfile(profile.ID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
dialog.ShowError(errors.New("failed to select profile"), s.wProfiles)
|
||||
@@ -130,7 +130,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
logoutBtn.Show()
|
||||
logoutBtn.SetText("Deregister")
|
||||
logoutBtn.OnTapped = func() {
|
||||
s.handleProfileLogout(profile.Name, refresh)
|
||||
s.handleProfileLogout(profile, refresh)
|
||||
}
|
||||
|
||||
// Remove profile
|
||||
@@ -144,7 +144,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
return
|
||||
}
|
||||
|
||||
err = s.removeProfile(profile.Name)
|
||||
err = s.removeProfile(profile.ID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles)
|
||||
@@ -250,7 +250,7 @@ func (s *serviceClient) addProfile(profileName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) switchProfile(profileName string) error {
|
||||
func (s *serviceClient) switchProfile(handle string) error {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf(getClientFMT, err)
|
||||
@@ -261,15 +261,15 @@ func (s *serviceClient) switchProfile(profileName string) error {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profileName,
|
||||
resp, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
Username: &currUser.Username,
|
||||
}); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile failed: %w", err)
|
||||
}
|
||||
|
||||
err = s.profileManager.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
if err := s.profileManager.SwitchProfile(profilemanager.ID(resp.Id)); err != nil {
|
||||
return fmt.Errorf("switch profile: %w", err)
|
||||
}
|
||||
|
||||
@@ -299,10 +299,27 @@ func (s *serviceClient) removeProfile(profileName string) error {
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
ID string
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
// formatProfileLabel returns the display label for a profile. Profiles can
|
||||
// share the same Name, so when more than one profile in profiles carries this
|
||||
// Name, a short form of the ID is appended to disambiguate the entries.
|
||||
func formatProfileLabel(profile Profile, profiles []Profile) string {
|
||||
count := 0
|
||||
for _, p := range profiles {
|
||||
if p.Name == profile.Name {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count <= 1 {
|
||||
return profile.Name
|
||||
}
|
||||
return fmt.Sprintf("%s (%s)", profile.Name, profilemanager.ID(profile.ID).ShortID())
|
||||
}
|
||||
|
||||
func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
@@ -324,6 +341,7 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
|
||||
for _, profile := range profilesResp.Profiles {
|
||||
profiles = append(profiles, Profile{
|
||||
ID: profile.Id,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
})
|
||||
@@ -332,10 +350,10 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
|
||||
func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback func()) {
|
||||
dialog.ShowConfirm(
|
||||
"Deregister",
|
||||
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
|
||||
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profile.Name),
|
||||
func(confirm bool) {
|
||||
if !confirm {
|
||||
return
|
||||
@@ -356,8 +374,10 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
|
||||
}
|
||||
|
||||
username := currUser.Username
|
||||
// ProfileName is treated as a handle; send the ID so the
|
||||
// daemon resolves to exactly this profile.
|
||||
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
|
||||
ProfileName: &profileName,
|
||||
ProfileName: &profile.ID,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -368,7 +388,7 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
|
||||
|
||||
dialog.ShowInformation(
|
||||
"Deregistered",
|
||||
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
|
||||
fmt.Sprintf("Successfully deregistered from '%s'", profile.Name),
|
||||
s.wProfiles,
|
||||
)
|
||||
|
||||
@@ -461,6 +481,7 @@ func (p *profileMenu) getProfiles() ([]Profile, error) {
|
||||
|
||||
for _, profile := range profilesResp.Profiles {
|
||||
profiles = append(profiles, Profile{
|
||||
ID: profile.Id,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
})
|
||||
@@ -501,7 +522,7 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
|
||||
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
|
||||
activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName)
|
||||
activeProfState, err := p.profileManager.GetProfileState(profilemanager.ID(activeProf.Id))
|
||||
if err != nil {
|
||||
log.Warnf("failed to get active profile state: %v", err)
|
||||
p.emailMenuItem.Hide()
|
||||
@@ -512,7 +533,7 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
|
||||
for _, profile := range profiles {
|
||||
item := p.profileMenuItem.AddSubMenuItem(profile.Name, "")
|
||||
item := p.profileMenuItem.AddSubMenuItem(formatProfileLabel(profile, profiles), "")
|
||||
if profile.IsActive {
|
||||
item.Check()
|
||||
}
|
||||
@@ -541,8 +562,8 @@ func (p *profileMenu) refresh() {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profile.Name,
|
||||
switchResp, err := conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profile.ID,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -552,7 +573,7 @@ func (p *profileMenu) refresh() {
|
||||
return
|
||||
}
|
||||
|
||||
err = p.profileManager.SwitchProfile(profile.Name)
|
||||
err = p.profileManager.SwitchProfile(profilemanager.ID(switchResp.Id))
|
||||
if err != nil {
|
||||
log.Errorf("failed to switch profile '%s': %v", profile.Name, err)
|
||||
return
|
||||
@@ -727,7 +748,10 @@ func (p *profileMenu) updateMenu() {
|
||||
}
|
||||
|
||||
sort.Slice(profiles, func(i, j int) bool {
|
||||
return profiles[i].Name < profiles[j].Name
|
||||
if profiles[i].Name != profiles[j].Name {
|
||||
return profiles[i].Name < profiles[j].Name
|
||||
}
|
||||
return profiles[i].ID < profiles[j].ID
|
||||
})
|
||||
|
||||
p.mu.Lock()
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -30,6 +31,7 @@ const (
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
dialWebSocketTimeout = 30 * time.Second
|
||||
|
||||
icmpEchoRequest = 8
|
||||
icmpCodeEcho = 0
|
||||
@@ -677,6 +679,7 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["dialWebSocket"] = createDialWebSocketMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
@@ -691,6 +694,74 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
func createDialWebSocketMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
url, protocols, timeout, errVal := parseDialWebSocketArgs(args)
|
||||
if !errVal.IsUndefined() {
|
||||
return errVal
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := nbwebsocket.Dial(ctx, client, url, protocols)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
resolve.Invoke(nbwebsocket.NewJSInterface(conn))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func parseDialWebSocketArgs(args []js.Value) (url string, protocols []string, timeout time.Duration, errVal js.Value) {
|
||||
if len(args) < 1 || args[0].Type() != js.TypeString {
|
||||
return "", nil, 0, js.ValueOf("error: dialWebSocket requires a URL string argument")
|
||||
}
|
||||
url = args[0].String()
|
||||
|
||||
if len(args) >= 2 && !args[1].IsNull() && !args[1].IsUndefined() {
|
||||
arr, err := jsStringArray(args[1])
|
||||
if err != nil {
|
||||
return "", nil, 0, js.ValueOf(fmt.Sprintf("error: protocols: %v", err))
|
||||
}
|
||||
protocols = arr
|
||||
}
|
||||
|
||||
timeout = dialWebSocketTimeout
|
||||
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
|
||||
if args[2].Type() != js.TypeNumber {
|
||||
return "", nil, 0, js.ValueOf("error: timeoutMs must be a number")
|
||||
}
|
||||
timeoutMs := args[2].Int()
|
||||
if timeoutMs <= 0 {
|
||||
return "", nil, 0, js.ValueOf("error: timeout must be positive")
|
||||
}
|
||||
timeout = time.Duration(timeoutMs) * time.Millisecond
|
||||
}
|
||||
|
||||
return url, protocols, timeout, js.Undefined()
|
||||
}
|
||||
|
||||
// jsStringArray converts a JS array of strings to a Go []string.
|
||||
func jsStringArray(v js.Value) ([]string, error) {
|
||||
if !v.InstanceOf(js.Global().Get("Array")) {
|
||||
return nil, fmt.Errorf("expected array")
|
||||
}
|
||||
n := v.Length()
|
||||
out := make([]string, n)
|
||||
for i := 0; i < n; i++ {
|
||||
el := v.Index(i)
|
||||
if el.Type() != js.TypeString {
|
||||
return nil, fmt.Errorf("element %d is not a string", i)
|
||||
}
|
||||
out[i] = el.String()
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
|
||||
304
client/wasm/internal/websocket/websocket.go
Normal file
304
client/wasm/internal/websocket/websocket.go
Normal file
@@ -0,0 +1,304 @@
|
||||
//go:build js
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type closeError struct {
|
||||
code uint16
|
||||
reason string
|
||||
}
|
||||
|
||||
func (e *closeError) Error() string {
|
||||
return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason)
|
||||
}
|
||||
|
||||
// bufferedConn fronts a net.Conn with a reader that serves any bytes buffered
|
||||
// during the WebSocket handshake before falling through to the raw conn.
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) }
|
||||
|
||||
// Conn wraps a WebSocket connection over a NetBird TCP connection.
|
||||
type Conn struct {
|
||||
conn net.Conn
|
||||
mu sync.Mutex
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// Dial establishes a WebSocket connection to the given URL through the NetBird network.
|
||||
// Optional protocols are sent via the Sec-WebSocket-Protocol header.
|
||||
func Dial(ctx context.Context, client *netbird.Client, rawURL string, protocols []string) (*Conn, error) {
|
||||
d := ws.Dialer{
|
||||
NetDial: client.Dial,
|
||||
Protocols: protocols,
|
||||
}
|
||||
|
||||
conn, br, _, err := d.Dial(ctx, rawURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("websocket dial: %w", err)
|
||||
}
|
||||
|
||||
// br is non-nil when the server pushed frames alongside the handshake
|
||||
// response; those bytes live in the bufio.Reader and must be drained
|
||||
// before reading from conn, otherwise we'd skip the first frames.
|
||||
if br != nil {
|
||||
if br.Buffered() > 0 {
|
||||
conn = &bufferedConn{Conn: conn, r: io.MultiReader(br, conn)}
|
||||
} else {
|
||||
ws.PutReader(br)
|
||||
}
|
||||
}
|
||||
|
||||
return &Conn{
|
||||
conn: conn,
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadMessage reads the next WebSocket message, handling control frames automatically.
|
||||
func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) {
|
||||
for {
|
||||
msgs, err := wsutil.ReadServerMessage(c.conn, nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
if msg.OpCode.IsControl() {
|
||||
if err := c.handleControl(msg); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return msg.OpCode, msg.Payload, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) handleControl(msg wsutil.Message) error {
|
||||
switch msg.OpCode {
|
||||
case ws.OpPing:
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload)
|
||||
case ws.OpClose:
|
||||
code, reason := parseClosePayload(msg.Payload)
|
||||
return &closeError{code: code, reason: reason}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteText sends a text WebSocket message.
|
||||
func (c *Conn) WriteText(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpText, data)
|
||||
}
|
||||
|
||||
// WriteBinary sends a binary WebSocket message.
|
||||
func (c *Conn) WriteBinary(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data)
|
||||
}
|
||||
|
||||
// Close sends a close frame with StatusNormalClosure and closes the underlying connection.
|
||||
func (c *Conn) Close() error {
|
||||
return c.closeWith(ws.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
// closeWith sends a close frame with the given code/reason and closes the underlying connection.
|
||||
// Used to echo the server's code when responding to a server-initiated close per RFC 6455 §5.5.1.
|
||||
func (c *Conn) closeWith(code ws.StatusCode, reason string) error {
|
||||
var first bool
|
||||
c.closeOnce.Do(func() {
|
||||
first = true
|
||||
close(c.closed)
|
||||
|
||||
c.mu.Lock()
|
||||
_ = wsutil.WriteClientMessage(c.conn, ws.OpClose, ws.NewCloseFrameBody(code, reason))
|
||||
c.mu.Unlock()
|
||||
|
||||
c.closeErr = c.conn.Close()
|
||||
})
|
||||
|
||||
if !first {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
// NewJSInterface creates a JavaScript object wrapping the WebSocket connection.
|
||||
// It exposes: send(string|Uint8Array), close(), and callback properties
|
||||
// onmessage, onclose, onerror.
|
||||
//
|
||||
// Callback properties may be set from the JS thread while the read loop
|
||||
// goroutine reads them. In WASM this is safe because Go and JS share a
|
||||
// single thread, but the design would need synchronization on
|
||||
// multi-threaded runtimes.
|
||||
func NewJSInterface(conn *Conn) js.Value {
|
||||
obj := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
sendFunc := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
log.Errorf("websocket send requires a data argument")
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
switch data.Type() {
|
||||
case js.TypeString:
|
||||
if err := conn.WriteText([]byte(data.String())); err != nil {
|
||||
log.Errorf("failed to send websocket text: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
default:
|
||||
buf, err := jsToBytes(data)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert js value to bytes: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
if err := conn.WriteBinary(buf); err != nil {
|
||||
log.Errorf("failed to send websocket binary: %v", err)
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
}
|
||||
return js.ValueOf(true)
|
||||
})
|
||||
obj.Set("send", sendFunc)
|
||||
|
||||
closeFunc := js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close websocket: %v", err)
|
||||
}
|
||||
return js.Undefined()
|
||||
})
|
||||
obj.Set("close", closeFunc)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("close websocket on readLoop exit: %v", err)
|
||||
}
|
||||
}()
|
||||
readLoop(conn, obj)
|
||||
// Undefining before Release turns post-close JS calls into TypeError
|
||||
// instead of a silent "call to released function".
|
||||
obj.Set("send", js.Undefined())
|
||||
obj.Set("close", js.Undefined())
|
||||
sendFunc.Release()
|
||||
closeFunc.Release()
|
||||
}()
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
func jsToBytes(data js.Value) ([]byte, error) {
|
||||
var uint8Array js.Value
|
||||
switch {
|
||||
case data.InstanceOf(js.Global().Get("Uint8Array")):
|
||||
uint8Array = data
|
||||
case data.InstanceOf(js.Global().Get("ArrayBuffer")):
|
||||
uint8Array = js.Global().Get("Uint8Array").New(data)
|
||||
default:
|
||||
return nil, fmt.Errorf("send: unsupported data type, use string, Uint8Array, or ArrayBuffer")
|
||||
}
|
||||
|
||||
buf := make([]byte, uint8Array.Get("length").Int())
|
||||
js.CopyBytesToGo(buf, uint8Array)
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func readLoop(conn *Conn, obj js.Value) {
|
||||
var ce *closeError
|
||||
defer func() { invokeOnClose(obj, ce) }()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-conn.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
op, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
ce = handleReadError(conn, obj, err)
|
||||
return
|
||||
}
|
||||
|
||||
dispatchMessage(obj, op, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func handleReadError(conn *Conn, obj js.Value, err error) *closeError {
|
||||
var ce *closeError
|
||||
if errors.As(err, &ce) {
|
||||
if cerr := conn.closeWith(ws.StatusCode(ce.code), ce.reason); cerr != nil {
|
||||
log.Debugf("failed to close websocket after server close frame: %v", cerr)
|
||||
}
|
||||
return ce
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
if onerror := obj.Get("onerror"); onerror.Truthy() {
|
||||
onerror.Invoke(js.ValueOf(err.Error()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func invokeOnClose(obj js.Value, ce *closeError) {
|
||||
onclose := obj.Get("onclose")
|
||||
if !onclose.Truthy() {
|
||||
return
|
||||
}
|
||||
if ce != nil {
|
||||
onclose.Invoke(js.ValueOf(int(ce.code)), js.ValueOf(ce.reason))
|
||||
return
|
||||
}
|
||||
onclose.Invoke()
|
||||
}
|
||||
|
||||
func dispatchMessage(obj js.Value, op ws.OpCode, payload []byte) {
|
||||
onmessage := obj.Get("onmessage")
|
||||
if !onmessage.Truthy() {
|
||||
return
|
||||
}
|
||||
switch op {
|
||||
case ws.OpText:
|
||||
onmessage.Invoke(js.ValueOf(string(payload)))
|
||||
case ws.OpBinary:
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(payload))
|
||||
js.CopyBytesToJS(uint8Array, payload)
|
||||
onmessage.Invoke(uint8Array)
|
||||
}
|
||||
}
|
||||
|
||||
func parseClosePayload(payload []byte) (uint16, string) {
|
||||
if len(payload) < 2 {
|
||||
return 1005, "" // RFC 6455: No Status Rcvd
|
||||
}
|
||||
code := binary.BigEndian.Uint16(payload[:2])
|
||||
return code, string(payload[2:])
|
||||
}
|
||||
@@ -2,4 +2,5 @@ FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-server" ]
|
||||
CMD ["--config", "/etc/netbird/config.yaml"]
|
||||
COPY netbird-server /go/bin/netbird-server
|
||||
ARG TARGETPLATFORM
|
||||
COPY ${TARGETPLATFORM}/netbird-server /go/bin/netbird-server
|
||||
|
||||
3
go.mod
3
go.mod
@@ -56,6 +56,7 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
github.com/gobwas/ws v1.4.0
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
@@ -215,6 +216,8 @@ require (
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/go-webauthn/webauthn v0.16.4 // indirect
|
||||
github.com/go-webauthn/x v0.2.3 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
|
||||
7
go.sum
7
go.sum
@@ -249,6 +249,12 @@ github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4C
|
||||
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
|
||||
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
|
||||
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
|
||||
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
|
||||
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
|
||||
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
|
||||
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
@@ -845,6 +851,7 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -41,6 +41,8 @@ type Config struct {
|
||||
GRPCAddr string
|
||||
}
|
||||
|
||||
const localConnectorID = "local"
|
||||
|
||||
// Provider wraps a Dex server
|
||||
type Provider struct {
|
||||
config *Config
|
||||
@@ -544,7 +546,7 @@ func (p *Provider) CreateUser(ctx context.Context, email, username, password str
|
||||
|
||||
// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
|
||||
// This matches the format Dex uses in JWT tokens
|
||||
encodedID := EncodeDexUserID(userID, "local")
|
||||
encodedID := EncodeDexUserID(userID, localConnectorID)
|
||||
return encodedID, nil
|
||||
}
|
||||
|
||||
@@ -619,6 +621,13 @@ func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
|
||||
return userID, connectorID, nil
|
||||
}
|
||||
|
||||
// IsLocalUserID reports whether encodedID is a Dex subject for the built-in
|
||||
// local password connector.
|
||||
func IsLocalUserID(encodedID string) bool {
|
||||
_, connectorID, err := DecodeDexUserID(encodedID)
|
||||
return err == nil && connectorID == localConnectorID
|
||||
}
|
||||
|
||||
// GetUser returns a user by email
|
||||
func (p *Provider) GetUser(ctx context.Context, email string) (storage.Password, error) {
|
||||
return p.storage.GetPassword(ctx, email)
|
||||
|
||||
@@ -115,6 +115,26 @@ func TestDecodeDexUserID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalUserID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
encodedID string
|
||||
want bool
|
||||
}{
|
||||
{name: "local connector", encodedID: EncodeDexUserID("7aad8c05-3287-473f-b42a-365504bf25e7", "local"), want: true},
|
||||
{name: "federated connector", encodedID: EncodeDexUserID("entra-user", "entra"), want: false},
|
||||
{name: "non-dex external IdP id", encodedID: "google-oauth2|1234567890", want: false},
|
||||
{name: "invalid base64", encodedID: "not-valid-base64!!!", want: false},
|
||||
{name: "empty", encodedID: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, IsLocalUserID(tt.encodedID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDexUserID(t *testing.T) {
|
||||
userID := "7aad8c05-3287-473f-b42a-365504bf25e7"
|
||||
connectorID := "local"
|
||||
|
||||
@@ -2,4 +2,5 @@ FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
|
||||
CMD ["--log-file", "console"]
|
||||
COPY netbird-mgmt /go/bin/netbird-mgmt
|
||||
ARG TARGETPLATFORM
|
||||
COPY ${TARGETPLATFORM}/netbird-mgmt /go/bin/netbird-mgmt
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
FROM ubuntu:24.04
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
|
||||
CMD ["--log-file", "console"]
|
||||
COPY netbird-mgmt /go/bin/netbird-mgmt
|
||||
@@ -45,7 +45,7 @@ type Controller struct {
|
||||
EphemeralPeersManager ephemeral.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
sendAccountUpdateLocks sync.Map
|
||||
affectedPeerUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
@@ -64,6 +64,13 @@ type bufferUpdate struct {
|
||||
update atomic.Bool
|
||||
}
|
||||
|
||||
type bufferAffectedUpdate struct {
|
||||
sendMu sync.Mutex
|
||||
dataMu sync.Mutex
|
||||
next *time.Timer
|
||||
peerIDs map[string]struct{}
|
||||
}
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
|
||||
@@ -201,7 +208,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
@@ -226,44 +233,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
@@ -273,6 +242,143 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
||||
return c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers updates only the specified peers that belong to an account.
|
||||
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
|
||||
|
||||
if !c.hasConnectedPeers(peerIDs) {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
|
||||
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
|
||||
if len(peersToUpdate) == 0 {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get validate peers: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return fmt.Errorf("failed to get proxy network maps: %v", err)
|
||||
}
|
||||
|
||||
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get flow enabled status: %v", err)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
for _, peer := range peersToUpdate {
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
|
||||
for _, id := range peerIDs {
|
||||
if c.peersUpdateManager.HasChannel(id) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
|
||||
affected := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
affected[id] = struct{}{}
|
||||
}
|
||||
|
||||
var result []*nbpeer.Peer
|
||||
for _, peer := range account.Peers {
|
||||
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||
if !c.peersUpdateManager.HasChannel(peerId) {
|
||||
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
|
||||
@@ -381,66 +487,164 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
|
||||
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
})
|
||||
b := bufUpd.(*bufferAffectedUpdate)
|
||||
|
||||
b.addPeerIDs(peerIDs)
|
||||
|
||||
if !b.sendMu.TryLock() {
|
||||
// Another goroutine is already sending; it will pick up our IDs on its next drain.
|
||||
return nil
|
||||
}
|
||||
|
||||
b.stopTimer()
|
||||
|
||||
// The send and the debounced timer outlive the calling request, so detach from
|
||||
// its context to avoid sending with a cancelled context once the handler returns.
|
||||
bgCtx := context.WithoutCancel(ctx)
|
||||
|
||||
collected := b.drainPeerIDs()
|
||||
go func() {
|
||||
defer b.sendMu.Unlock()
|
||||
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, collected)
|
||||
|
||||
// Check if more peer IDs accumulated while we were sending.
|
||||
if !b.hasPending() {
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule a debounced flush for the newly accumulated IDs.
|
||||
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
ids := b.drainPeerIDs()
|
||||
if len(ids) > 0 {
|
||||
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, ids)
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
|
||||
b.dataMu.Lock()
|
||||
for _, id := range ids {
|
||||
b.peerIDs[id] = struct{}{}
|
||||
}
|
||||
b.dataMu.Unlock()
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if len(b.peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
ids := make([]string, 0, len(b.peerIDs))
|
||||
for id := range b.peerIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
b.peerIDs = make(map[string]struct{})
|
||||
return ids
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) hasPending() bool {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
return len(b.peerIDs) > 0
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) stopTimer() {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(d, f)
|
||||
return
|
||||
}
|
||||
b.next.Reset(d)
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
return emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peerID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
return networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
@@ -578,21 +782,24 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -625,7 +832,11 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
|
||||
@@ -19,17 +19,19 @@ const (
|
||||
|
||||
type Controller interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
|
||||
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
CountStreams() int
|
||||
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
|
||||
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
|
||||
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
|
||||
|
||||
@@ -57,6 +57,20 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers mocks base method.
|
||||
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
|
||||
}
|
||||
|
||||
// CountStreams mocks base method.
|
||||
func (m *MockController) CountStreams() int {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -113,21 +127,20 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(int64)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, peerID)
|
||||
ret0, _ := ret[0].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].([]*posture.Checks)
|
||||
ret2, _ := ret[2].(int64)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerConnected mocks base method.
|
||||
@@ -158,45 +171,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
|
||||
}
|
||||
|
||||
// OnPeersAdded mocks base method.
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersAdded indicates an expected call of OnPeersAdded.
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// OnPeersDeleted mocks base method.
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// OnPeersUpdated mocks base method.
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
|
||||
}
|
||||
|
||||
// StartWarmup mocks base method.
|
||||
@@ -250,3 +263,17 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers mocks base method.
|
||||
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
|
||||
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
@@ -242,7 +242,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
|
||||
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create proxy peer: %w", err)
|
||||
}
|
||||
|
||||
@@ -918,6 +918,10 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
|
||||
}
|
||||
|
||||
for _, svc := range services {
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, svc.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
@@ -1270,6 +1274,10 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
|
||||
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
|
||||
}
|
||||
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service: %w", err)
|
||||
}
|
||||
@@ -1319,6 +1327,10 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service: %w", err)
|
||||
}
|
||||
|
||||
@@ -458,6 +458,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
txMock.EXPECT().
|
||||
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
|
||||
Return(newEphemeralService(), nil)
|
||||
txMock.EXPECT().
|
||||
DeleteServiceTargets(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
txMock.EXPECT().
|
||||
DeleteService(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
@@ -560,6 +563,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
txMock.EXPECT().
|
||||
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
|
||||
Return(newEphemeralService(), nil)
|
||||
txMock.EXPECT().
|
||||
DeleteServiceTargets(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
txMock.EXPECT().
|
||||
DeleteService(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
@@ -604,6 +610,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
txMock.EXPECT().
|
||||
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
|
||||
Return(newEphemeralService(), nil)
|
||||
txMock.EXPECT().
|
||||
DeleteServiceTargets(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
txMock.EXPECT().
|
||||
DeleteService(ctx, accountID, serviceID).
|
||||
Return(nil)
|
||||
@@ -1192,6 +1201,67 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
|
||||
}
|
||||
|
||||
func TestDeleteExpiredPeerService_DeletesTargets(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
|
||||
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before reaping")
|
||||
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
err = mgr.deleteExpiredPeerService(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "expired peer-exposed service should be deleted")
|
||||
s, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.NotFound, s.Type())
|
||||
|
||||
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, targets, 0, "orphaned target rows must be deleted when an expired peer-exposed service is reaped")
|
||||
}
|
||||
|
||||
func TestDeleteServiceFromPeer_DeletesTargets(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
|
||||
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before stopping")
|
||||
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.Error(t, err, "stopped peer-exposed service should be deleted")
|
||||
s, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.NotFound, s.Type())
|
||||
|
||||
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, targets, 0, "orphaned target rows must be deleted when a peer stops its exposed service")
|
||||
}
|
||||
|
||||
func TestValidateProtocolChange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -778,7 +778,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||
peer, network, postureChecks, enableSSH, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||
WireGuardPubKey: peerKey.String(),
|
||||
SSHKey: string(sshKey),
|
||||
Meta: peerMeta,
|
||||
@@ -792,7 +792,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, network, postureChecks, enableSSH)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
@@ -895,7 +895,7 @@ func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMess
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, network *types.Network, postureChecks []*posture.Checks, enableSSH bool) (*proto.LoginResponse, error) {
|
||||
var relayToken *Token
|
||||
var err error
|
||||
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
|
||||
@@ -914,7 +914,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
|
||||
PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -1588,7 +1589,10 @@ func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Contex
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
// Child accounts and PAT-authenticated requests do not sync JWT groups.
|
||||
// Embedded-Dex local users also skip sync because local password authentication
|
||||
// does not provide external IdP group claims.
|
||||
if userAuth.IsChild || userAuth.IsPAT || dex.IsLocalUserID(userAuth.UserId) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1890,7 +1894,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -2573,7 +2577,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
|
||||
changedPeerIDs := []string{peerID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -2664,7 +2670,9 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
|
||||
changedPeerIDs := []string{peerID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return fmt.Errorf("notify network map controller: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -61,7 +62,7 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -69,7 +70,7 @@ type Manager interface {
|
||||
UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
|
||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@@ -108,8 +109,8 @@ type Manager interface {
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) // used by peer gRPC API
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
@@ -128,6 +129,7 @@ type Manager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
dns "github.com/netbirdio/netbird/dns"
|
||||
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
activity "github.com/netbirdio/netbird/management/server/activity"
|
||||
affectedpeers "github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
idp "github.com/netbirdio/netbird/management/server/idp"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
posture "github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -79,14 +80,15 @@ func (mr *MockManagerMockRecorder) AccountExists(ctx, accountID interface{}) *go
|
||||
}
|
||||
|
||||
// AddPeer mocks base method.
|
||||
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddPeer", ctx, accountID, setupKey, userID, p, temporary)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].(*types.Network)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
ret3, _ := ret[3].(bool)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// AddPeer indicates an expected call of AddPeer.
|
||||
@@ -1288,14 +1290,15 @@ func (mr *MockManagerMockRecorder) ListUsers(ctx, accountID interface{}) *gomock
|
||||
}
|
||||
|
||||
// LoginPeer mocks base method.
|
||||
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoginPeer", ctx, login)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].(*types.Network)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
ret3, _ := ret[3].(bool)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// LoginPeer indicates an expected call of LoginPeer.
|
||||
@@ -1320,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
@@ -1637,6 +1640,18 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// ExpandAndUpdateAffected mocks base method.
|
||||
func (m *MockManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ExpandAndUpdateAffected", ctx, accountID, snap, change)
|
||||
}
|
||||
|
||||
// ExpandAndUpdateAffected indicates an expected call of ExpandAndUpdateAffected.
|
||||
func (mr *MockManagerMockRecorder) ExpandAndUpdateAffected(ctx, accountID, snap, change interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpandAndUpdateAffected", reflect.TypeOf((*MockManager)(nil).ExpandAndUpdateAffected), ctx, accountID, snap, change)
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks base method.
|
||||
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -83,7 +84,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
|
||||
setupKey = key.Key
|
||||
}
|
||||
|
||||
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
|
||||
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
|
||||
if err != nil {
|
||||
t.Error("expected to add new peer successfully after creating new account, but failed", err)
|
||||
}
|
||||
@@ -723,6 +724,28 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
require.Equal(t, g2.Name, "group2", "group2 name should match")
|
||||
require.Equal(t, g2.Issued, types.GroupIssuedJWT, "group2 issued should match")
|
||||
})
|
||||
t.Run("local embedded-Dex user is skipped", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
localClaims := auth.UserAuth{
|
||||
AccountId: accountID,
|
||||
Domain: domain,
|
||||
UserId: dex.EncodeDexUserID("local-owner", "local"),
|
||||
Groups: []string{"group3", "group4"},
|
||||
}
|
||||
err = manager.SyncUserJWTGroups(context.Background(), localClaims)
|
||||
require.NoError(t, err, "sync should be a no-op for local users")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
for _, g := range account.Groups {
|
||||
require.NotEqual(t, "group3", g.Name, "local user JWT groups must not be synced")
|
||||
require.NotEqual(t, "group4", g.Name, "local user JWT groups must not be synced")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
@@ -1069,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
}, false)
|
||||
@@ -1133,7 +1156,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
}, false)
|
||||
@@ -1481,7 +1504,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
|
||||
peerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
|
||||
}, false)
|
||||
@@ -1803,7 +1826,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -1813,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1859,7 +1882,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -1884,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1904,7 +1927,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
}, false)
|
||||
@@ -1912,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1933,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1957,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1967,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1994,7 +2017,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
|
||||
}, false)
|
||||
@@ -2029,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -2057,7 +2080,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
@@ -2070,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -3253,7 +3276,7 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
Status: &nbpeer.PeerStatus{
|
||||
@@ -3282,6 +3305,19 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -3408,7 +3444,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
WireGuardPubKey: account.Peers["peer-1"].Key,
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
@@ -3477,7 +3513,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
@@ -3872,13 +3908,13 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer1")
|
||||
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
}, false)
|
||||
|
||||
117
management/server/affected_peers_coverage_test.go
Normal file
117
management/server/affected_peers_coverage_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestAffectedPeers_DependencyCoverageMatrix enumerates each network-map
|
||||
// dependency crossed with the change-type that can alter it, asserting the
|
||||
// resolver folds in exactly the peers whose map changes. A new dependency that
|
||||
// the resolver fails to walk should fail one of these rows; a new change-type
|
||||
// without a row is a coverage gap to add here.
|
||||
func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
type row struct {
|
||||
name string
|
||||
build func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string)
|
||||
}
|
||||
|
||||
rows := []row{
|
||||
{
|
||||
name: "policy-groups/source-group-change refreshes source+routing, excludes unrelated",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resource-routing-bridge/router-peer-change refreshes policy sources",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "policy-change/explicit-policy refreshes source+routing",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "policy-destinationresource/explicit-policy bridges to routing peer",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resource-change refreshes source+routing on its network",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{Resources: []*resourceTypes.NetworkResource{
|
||||
{ID: s.resourceID, NetworkID: s.networkID, GroupIDs: []string{s.resourceGroupID}},
|
||||
}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "network-change refreshes source+routing on that network",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "posture-check-change refreshes source+routing of gated policy",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
Name: "cov-min-version",
|
||||
Checks: posture.ChecksDefinition{NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.SourcePostureChecks = []string{check.ID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{PostureCheckIDs: []string{check.ID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
t.Run(r.name, func(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
|
||||
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
143
management/server/affected_peers_oldstate_test.go
Normal file
143
management/server/affected_peers_oldstate_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// An update spans an old and a new state. The affected set must be the UNION of
|
||||
// peers reachable before and after the change; resolving only against the final
|
||||
// state drops peers that were reachable but no longer are. These tests pin the
|
||||
// two paths where the old state is reachable only by the changed object's
|
||||
// previous references: detaching a resource group, and re-pointing a router peer.
|
||||
|
||||
// TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources:
|
||||
// a resource is reachable by a source group via two destination resource groups;
|
||||
// detaching one of them must still refresh that group's policy source peers, even
|
||||
// though the post-update resource no longer maps to it.
|
||||
func TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A second resource group + a second source group/peer that reaches the
|
||||
// resource only through that second group.
|
||||
const detachGroupID = "rs-detach-grp"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: detachGroupID, Name: "rs-detach"}))
|
||||
|
||||
const secondSourceGroupID = "rs-source-grp-2"
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-detach-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
|
||||
}))
|
||||
|
||||
resourcesManager, _, _ := s.managers()
|
||||
|
||||
// Attach the resource to the detach group as well: now in [resourceGroup, detachGroup].
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID, detachGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Policy granting the second source group access via the detach group.
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(secondSourceGroupID, detachGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
secondSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, secondSourcePeer.ID) })
|
||||
settleAffectedUpdates(secondSrcCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Detaching the resource from detachGroup removes the second source's
|
||||
// access; that source peer must be refreshed even though the post-update
|
||||
// resource no longer maps to detachGroup.
|
||||
peerShouldReceiveUpdate(t, secondSrcCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID}, // detached detachGroup
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: detaching a resource group did not refresh the old group's policy source peer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer:
|
||||
// changing router.Peer within the same network must still refresh the OLD routing
|
||||
// peer, which loses its routing role.
|
||||
func TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, routersManager, _ := s.managers()
|
||||
|
||||
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routers, 1)
|
||||
router := routers[0]
|
||||
oldRoutingPeer := router.Peer
|
||||
require.NotEmpty(t, oldRoutingPeer)
|
||||
|
||||
// A new peer to become the routing peer in place of the old one.
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-newrouter-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
newRoutingPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
|
||||
oldCh := s.updateManager.CreateChannel(ctx, oldRoutingPeer)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, oldRoutingPeer) })
|
||||
settleAffectedUpdates(oldCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// The old routing peer stops serving the resource and must be refreshed.
|
||||
peerShouldReceiveUpdate(t, oldCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = routersManager.UpdateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
ID: router.ID,
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: newRoutingPeer.ID, // repoint within the same network
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: re-pointing the router peer did not refresh the old routing peer")
|
||||
}
|
||||
}
|
||||
255
management/server/affected_peers_property_test.go
Normal file
255
management/server/affected_peers_property_test.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// allPeerMaps computes the serialized per-peer network map for every peer in the
|
||||
// account, mirroring the controller's compute path so the property test compares
|
||||
// against real output.
|
||||
func allPeerMaps(t *testing.T, manager *DefaultAccountManager, accountID string) map[string]string {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
validated := make(map[string]struct{}, len(account.Peers))
|
||||
for id := range account.Peers {
|
||||
validated[id] = struct{}{}
|
||||
}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
out := make(map[string]string, len(account.Peers))
|
||||
for peerID := range account.Peers {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validated, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
// Network.Serial is an account-global counter bumped on every change; it
|
||||
// is not a per-peer dependency, so normalize it out of the comparison.
|
||||
if nm.Network != nil {
|
||||
nm.Network.Serial = 0
|
||||
}
|
||||
out[peerID] = canonicalJSON(t, nm)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// canonicalJSON marshals v and returns an order-insensitive string form: every
|
||||
// JSON array is sorted by the canonical form of its elements. The network map's
|
||||
// Peers/Routes/FirewallRules/SourceRanges slices have nondeterministic order, so
|
||||
// a raw JSON compare would report spurious changes.
|
||||
func canonicalJSON(t *testing.T, v interface{}) string {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
var parsed interface{}
|
||||
require.NoError(t, json.Unmarshal(b, &parsed))
|
||||
canonicalized, err := json.Marshal(sortAny(parsed))
|
||||
require.NoError(t, err)
|
||||
return string(canonicalized)
|
||||
}
|
||||
|
||||
func sortAny(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
for i := range val {
|
||||
val[i] = sortAny(val[i])
|
||||
}
|
||||
sort.Slice(val, func(i, j int) bool {
|
||||
bi, _ := json.Marshal(val[i])
|
||||
bj, _ := json.Marshal(val[j])
|
||||
return string(bi) < string(bj)
|
||||
})
|
||||
return val
|
||||
case map[string]interface{}:
|
||||
for k := range val {
|
||||
val[k] = sortAny(val[k])
|
||||
}
|
||||
return val
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// changedPeers returns the peer IDs whose serialized map differs between before
|
||||
// and after.
|
||||
func changedPeers(before, after map[string]string) []string {
|
||||
var changed []string
|
||||
for id, b := range before {
|
||||
a, ok := after[id]
|
||||
if !ok || a != b {
|
||||
changed = append(changed, id)
|
||||
}
|
||||
}
|
||||
for id := range after {
|
||||
if _, ok := before[id]; !ok {
|
||||
changed = append(changed, id)
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// TestAffectedPeers_Property_ResolverSupersetsRealChanges builds a topology,
|
||||
// applies random changes, and asserts that the resolver's affected set is a
|
||||
// superset of the peers whose real network map actually changed. If the resolver
|
||||
// ever misses a dependency, a change will alter a peer's map without that peer
|
||||
// appearing in the affected set, failing here.
|
||||
func TestAffectedPeers_Property_ResolverSupersetsRealChanges(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A pre-existing peer->resource policy so the resource/router bridge is live.
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extra peers and groups to give mutations room to move membership around.
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "prop-key", types.SetupKeyReusable, 0, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
extraPeers := make([]string, 0, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
p := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
extraPeers = append(extraPeers, p.ID)
|
||||
}
|
||||
extraGroups := []string{"prop-grp-0", "prop-grp-1"}
|
||||
for _, g := range extraGroups {
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: g, Name: g}))
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
allGroups := append([]string{s.sourceGroupID, s.resourceGroupID, s.routerPeerGroupID}, extraGroups...)
|
||||
allPeers := append([]string{s.sourcePeerID, s.routerPeerID, s.routerGroupPeerID, s.unrelatedPeerID}, extraPeers...)
|
||||
|
||||
for iter := 0; iter < 60; iter++ {
|
||||
change, apply := s.randomMutation(t, rng, allGroups, allPeers)
|
||||
if apply == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
before := allPeerMaps(t, s.manager, s.accountID)
|
||||
|
||||
resolvedSet := make(map[string]struct{})
|
||||
resolve := func() {
|
||||
require.NoError(t, s.manager.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
|
||||
snap, err := affectedpeers.Load(ctx, tx, s.accountID, change)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, id := range snap.Expand(ctx, s.accountID, change) {
|
||||
resolvedSet[id] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// Resolve on both sides of the mutation and union: removals are visible
|
||||
// only pre-apply (the leaving peer is still a member), additions only
|
||||
// post-apply (the joining peer is now a member). Production captures both
|
||||
// via per-path handling (e.g. UpdateGroup passes peersToRemove); the union
|
||||
// models that without coupling the test to each path's ordering.
|
||||
resolve()
|
||||
changedIDs := change.ChangedPeerIDs
|
||||
apply()
|
||||
resolve()
|
||||
|
||||
after := allPeerMaps(t, s.manager, s.accountID)
|
||||
|
||||
// The explicitly-changed peer's own map refresh is the caller's
|
||||
// responsibility (the resolver returns the peers to propagate to), so it
|
||||
// is allowed to be absent from the resolved set.
|
||||
changedExplicitly := make(map[string]struct{}, len(changedIDs))
|
||||
for _, id := range changedIDs {
|
||||
changedExplicitly[id] = struct{}{}
|
||||
}
|
||||
|
||||
for _, id := range changedPeers(before, after) {
|
||||
if _, stillExists := after[id]; !stillExists {
|
||||
continue
|
||||
}
|
||||
if _, isExplicit := changedExplicitly[id]; isExplicit {
|
||||
continue
|
||||
}
|
||||
_, ok := resolvedSet[id]
|
||||
require.Truef(t, ok,
|
||||
"iter %d: peer %s network map changed but was not in the resolver's affected set %v (change=%+v)",
|
||||
iter, id, maps.Keys(resolvedSet), change)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// randomMutation picks a random change, returns the Change to resolve and a
|
||||
// function that applies the underlying store mutation. apply is nil when the
|
||||
// drawn mutation is a no-op for the current state.
|
||||
func (s *routerScenario) randomMutation(t *testing.T, rng *rand.Rand, allGroups, allPeers []string) (affectedpeers.Change, func()) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
switch rng.Intn(3) {
|
||||
case 0:
|
||||
groupID := allGroups[rng.Intn(len(allGroups))]
|
||||
peerID := allPeers[rng.Intn(len(allPeers))]
|
||||
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
|
||||
require.NoError(t, err)
|
||||
if slicesContains(grp.Peers, peerID) {
|
||||
return affectedpeers.Change{}, nil
|
||||
}
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
|
||||
func() {
|
||||
require.NoError(t, s.manager.GroupAddPeer(ctx, s.accountID, groupID, peerID))
|
||||
}
|
||||
case 1:
|
||||
groupID := allGroups[rng.Intn(len(allGroups))]
|
||||
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
|
||||
require.NoError(t, err)
|
||||
if len(grp.Peers) == 0 {
|
||||
return affectedpeers.Change{}, nil
|
||||
}
|
||||
peerID := grp.Peers[rng.Intn(len(grp.Peers))]
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
|
||||
func() {
|
||||
require.NoError(t, s.manager.GroupDeletePeer(ctx, s.accountID, groupID, peerID))
|
||||
}
|
||||
default:
|
||||
src := allGroups[rng.Intn(len(allGroups))]
|
||||
dst := allGroups[rng.Intn(len(allGroups))]
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: fmt.Sprintf("prop-policy-%d", rng.Int()),
|
||||
Rules: []*types.PolicyRule{{
|
||||
Enabled: true,
|
||||
Sources: []string{src},
|
||||
Destinations: []string{dst},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
}},
|
||||
}
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
func() {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func slicesContains(s []string, v string) bool {
|
||||
for _, x := range s {
|
||||
if x == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
164
management/server/affected_peers_querycount_test.go
Normal file
164
management/server/affected_peers_querycount_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// countingStore wraps a real store and counts the per-account collection loads
|
||||
// the resolver performs, so a test can assert each is read at most once and that
|
||||
// irrelevant collections are skipped entirely.
|
||||
type countingStore struct {
|
||||
store.Store
|
||||
mu sync.Mutex
|
||||
counts map[string]int
|
||||
}
|
||||
|
||||
func newCountingStore(s store.Store) *countingStore {
|
||||
return &countingStore{Store: s, counts: map[string]int{}}
|
||||
}
|
||||
|
||||
func (c *countingStore) bump(name string) {
|
||||
c.mu.Lock()
|
||||
c.counts[name]++
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *countingStore) count(name string) int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.counts[name]
|
||||
}
|
||||
|
||||
func (c *countingStore) total() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
n := 0
|
||||
for _, v := range c.counts {
|
||||
n += v
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountPolicies(ctx context.Context, ls store.LockingStrength, accountID string) ([]*types.Policy, error) {
|
||||
c.bump("policies")
|
||||
return c.Store.GetAccountPolicies(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountRoutes(ctx context.Context, ls store.LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
c.bump("routes")
|
||||
return c.Store.GetAccountRoutes(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountNameServerGroups(ctx context.Context, ls store.LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
c.bump("nameservers")
|
||||
return c.Store.GetAccountNameServerGroups(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountDNSSettings(ctx context.Context, ls store.LockingStrength, accountID string) (*types.DNSSettings, error) {
|
||||
c.bump("dnssettings")
|
||||
return c.Store.GetAccountDNSSettings(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetNetworkRoutersByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
c.bump("routers")
|
||||
return c.Store.GetNetworkRoutersByAccountID(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetNetworkResourcesByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
c.bump("resources")
|
||||
return c.Store.GetNetworkResourcesByAccountID(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountServices(ctx context.Context, ls store.LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
||||
c.bump("services")
|
||||
return c.Store.GetAccountServices(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_NoRedundantFullTableLoads asserts the resolver
|
||||
// loads each per-account collection at most once per Resolve (memoization) even
|
||||
// on a change that drives every bridge, and skips the services table when the
|
||||
// account has no embedded proxy peers.
|
||||
func TestAffectedPeers_QueryCount_NoRedundantFullTableLoads(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
|
||||
// A group change that exercises policies, routers, resources and the bridge.
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
|
||||
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
|
||||
require.NoError(t, err)
|
||||
affected := snap.Expand(ctx, s.accountID, change)
|
||||
assert.Contains(t, affected, s.routerPeerID, "bridge must still resolve the routing peer")
|
||||
|
||||
for _, name := range []string{"policies", "routes", "nameservers", "dnssettings", "routers", "resources"} {
|
||||
assert.LessOrEqualf(t, cs.count(name), 1,
|
||||
"%s must be loaded at most once per Resolve, got %d", name, cs.count(name))
|
||||
}
|
||||
assert.Equal(t, 0, cs.count("services"),
|
||||
"services must not be loaded when the account has no embedded proxy peers")
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads asserts that a change with
|
||||
// no group/peer signal touches no per-account collections beyond what its inputs
|
||||
// require.
|
||||
func TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
|
||||
// A bare network change drives only the router->source bridge: routers and
|
||||
// resources are needed, but routes/nameservers/dnssettings/services are not.
|
||||
_, err := affectedpeers.Load(ctx, cs, s.accountID, affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, cs.count("routes"), "routes must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("nameservers"), "nameservers must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("dnssettings"), "dnssettings must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("services"), "services must not be loaded for a network-only change")
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_ExpandReadsNothing is the core invariant of the
|
||||
// Load/Expand split: Load (run inside the transaction) does all store reads;
|
||||
// Expand (run after commit) must touch the store ZERO times, so it never holds
|
||||
// the write lock and never reads post-commit state.
|
||||
func TestAffectedPeers_QueryCount_ExpandReadsNothing(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, cs.total(), 0, "Load must read the store")
|
||||
|
||||
// Any store access during Expand would increment the same counter. Expand
|
||||
// operates purely on the snapshot, so the count must not move.
|
||||
readsAfterLoad := cs.total()
|
||||
affected := snap.Expand(ctx, s.accountID, change)
|
||||
assert.Contains(t, affected, s.routerPeerID, "Expand must still produce the affected peers from the snapshot")
|
||||
assert.Equal(t, readsAfterLoad, cs.total(), "Expand must perform zero store reads — it operates purely on the loaded snapshot")
|
||||
}
|
||||
333
management/server/affected_peers_router_paths_test.go
Normal file
333
management/server/affected_peers_router_paths_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (s *routerScenario) resolveGroupChangeAffected(ctx context.Context, changedGroupIDs []string) []string {
|
||||
change := affectedpeers.Change{ChangedGroupIDs: changedGroupIDs}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func (s *routerScenario) resolvePeerChangeAffected(ctx context.Context, changedPeerIDs []string) []string {
|
||||
change := affectedpeers.Change{ChangedPeerIDs: changedPeerIDs}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source group member must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"changing the source group of a peer->resource policy must refresh the resource's routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"changing the source group must refresh the routing peer defined via router.PeerGroups")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_RouterPeerGroupMembership_RefreshesPolicySources(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.routerPeerGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID, "the routing peer itself must be affected")
|
||||
assert.Contains(t, affected, s.sourcePeerID,
|
||||
"changing the router's PeerGroups must refresh the source peers of policies serving the resource")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_SourcePeer_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"a status change on a source peer must refresh the resource's routing peer that serves it")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_SourcePeer_ByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"DestinationResource-targeted policy must still bridge a source-peer change to the routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DeleteGroup_ResolvesAffectedPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const memberOnlyGroupID = "rs-memberonly-grp"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: memberOnlyGroupID, Name: "rs-memberonly", Peers: []string{s.sourcePeerID},
|
||||
}))
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{memberOnlyGroupID})
|
||||
assert.Empty(t, affected, "an unlinked group has no network-map impact, so no peer is affected")
|
||||
|
||||
require.NoError(t, s.manager.DeleteGroup(ctx, s.accountID, userID, memberOnlyGroupID))
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupAddResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const extraResourceGroupID = "rs-resource-grp-extra"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: extraResourceGroupID, Name: "rs-resource-extra",
|
||||
}))
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, extraResourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.manager.GroupAddResource(ctx, s.accountID, extraResourceGroupID, types.Resource{
|
||||
ID: s.resourceID,
|
||||
Type: types.ResourceTypeHost,
|
||||
}))
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{extraResourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"attaching a resource to a policy destination group must refresh the resource's routing peer")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "policy source peers must refresh")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func (s *routerScenario) createPostureCheckGatedPolicy(t *testing.T, ctx context.Context) string {
|
||||
t.Helper()
|
||||
|
||||
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
Name: "rs-min-version",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.SourcePostureChecks = []string{check.ID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
return check.ID
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_SavePostureCheck_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
checkID := s.createPostureCheckGatedPolicy(t, ctx)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
ID: checkID,
|
||||
Name: "rs-min-version",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.31.0"},
|
||||
},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: editing a posture check did not refresh source + routing peers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSourcePeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesManager, _, _ := s.managers()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/25",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: updating a DestinationResource-targeted resource did not refresh its policy source peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesManager, routersManager, _ := s.managers()
|
||||
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-disabled", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
disabledRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: disabledRouterPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9000,
|
||||
Enabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
|
||||
|
||||
settleAffectedUpdates(disabledCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, disabledCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/25",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "groupiso")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a source-group change for another resource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "peeriso")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a source-peer change for another resource")
|
||||
}
|
||||
771
management/server/affected_peers_router_test.go
Normal file
771
management/server/affected_peers_router_test.go
Normal file
@@ -0,0 +1,771 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// routerScenario captures the topology from the bug report:
|
||||
//
|
||||
// network ── router (routing peer) ── resource (in resourceGroup)
|
||||
// independent peer ──(policy: source -> resource)──> resource
|
||||
//
|
||||
// The routing peer must be refreshed when a policy grants a source peer access
|
||||
// to the resource, because the network map connects the source peer to the
|
||||
// routing peer at compute time (Account.GetPoliciesForNetworkResource +
|
||||
// addNetworksRoutingPeers). The routing peer is NOT a member of the resource
|
||||
// group, so static group/peer resolution alone cannot find it.
|
||||
type routerScenario struct {
|
||||
manager *DefaultAccountManager
|
||||
updateManager *update_channel.PeersUpdateManager
|
||||
accountID string
|
||||
networkID string
|
||||
|
||||
sourcePeerID string // independent peer that the policy grants access from
|
||||
sourceGroupID string // group containing the source peer
|
||||
|
||||
routerPeerID string // peer acting as the routing peer (direct router.Peer)
|
||||
routerGroupPeerID string // peer that is a member of routerPeerGroup
|
||||
routerPeerGroupID string // group used for router.PeerGroups
|
||||
|
||||
resourceID string // network resource
|
||||
resourceGroupID string // group whose member is the resource (no peers)
|
||||
|
||||
unrelatedPeerID string // peer in no relevant entity
|
||||
}
|
||||
|
||||
// setupRouterScenario builds the topology above with the default policy removed
|
||||
// and channels NOT yet created, so callers control exactly when updates can flow.
|
||||
func setupRouterScenario(t *testing.T, directRouterPeer bool) *routerScenario {
|
||||
t.Helper()
|
||||
|
||||
manager, updateManager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := createAccount(manager, "router_scenario", userID, "")
|
||||
require.NoError(t, err)
|
||||
accountID := account.Id
|
||||
|
||||
// Remove the default policy so AddPeer/CreateGroup don't schedule unrelated updates.
|
||||
policies, err := manager.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, p := range policies {
|
||||
require.NoError(t, manager.Store.DeletePolicy(ctx, accountID, p.ID))
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(ctx, accountID, "rs-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
sourcePeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
routerPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
routerGroupPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
unrelatedPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
|
||||
const (
|
||||
sourceGroupID = "rs-source-grp"
|
||||
routerPeerGroupID = "rs-router-grp"
|
||||
resourceGroupID = "rs-resource-grp"
|
||||
)
|
||||
|
||||
for _, g := range []*types.Group{
|
||||
{ID: sourceGroupID, Name: "rs-source", Peers: []string{sourcePeer.ID}},
|
||||
{ID: routerPeerGroupID, Name: "rs-router", Peers: []string{routerGroupPeer.ID}},
|
||||
{ID: resourceGroupID, Name: "rs-resource"}, // intentionally peerless; the resource is its only member
|
||||
} {
|
||||
require.NoError(t, manager.CreateGroup(ctx, accountID, userID, g))
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
|
||||
ID: "rs-network",
|
||||
AccountID: accountID,
|
||||
Name: "rs-network",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
AccountID: accountID,
|
||||
NetworkID: network.ID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
router := &routerTypes.NetworkRouter{
|
||||
ID: "rs-router",
|
||||
NetworkID: network.ID,
|
||||
AccountID: accountID,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
}
|
||||
if directRouterPeer {
|
||||
router.Peer = routerPeer.ID
|
||||
} else {
|
||||
router.PeerGroups = []string{routerPeerGroupID}
|
||||
}
|
||||
_, err = routersManager.CreateRouter(ctx, userID, router)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &routerScenario{
|
||||
manager: manager,
|
||||
updateManager: updateManager,
|
||||
accountID: accountID,
|
||||
networkID: network.ID,
|
||||
sourcePeerID: sourcePeer.ID,
|
||||
sourceGroupID: sourceGroupID,
|
||||
routerPeerID: routerPeer.ID,
|
||||
routerGroupPeerID: routerGroupPeer.ID,
|
||||
routerPeerGroupID: routerPeerGroupID,
|
||||
resourceID: resource.ID,
|
||||
resourceGroupID: resourceGroupID,
|
||||
unrelatedPeerID: unrelatedPeer.ID,
|
||||
}
|
||||
}
|
||||
|
||||
// peerToResourcePolicy builds a policy granting the source group access to the
|
||||
// resource, referencing the resource by its group in the rule destination.
|
||||
func peerToResourcePolicyByGroup(sourceGroupID, resourceGroupID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "peer-to-resource-by-group",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{sourceGroupID},
|
||||
Destinations: []string{resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// peerToResourcePolicyByResource builds a policy referencing the resource
|
||||
// directly via DestinationResource rather than its group.
|
||||
func peerToResourcePolicyByResource(sourceGroupID, resourceID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "peer-to-resource-by-resource",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{sourceGroupID},
|
||||
DestinationResource: types.Resource{ID: resourceID, Type: types.ResourceTypeHost},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolvePolicyAffected mirrors SavePolicy's resolution: resolve the affected
|
||||
// peers for the given policy.
|
||||
func (s *routerScenario) resolvePolicyAffected(ctx context.Context, policy *types.Policy) []string {
|
||||
change := affectedpeers.Change{Policies: []*types.Policy{policy}}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_SourcePeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// BUG: the direct routing peer serves the resource's subnet to the source
|
||||
// peer, so it must be refreshed when the policy is created. The policy path
|
||||
// only resolves the literal rule groups (source group + resource group);
|
||||
// the resource group has no peer members and the router peer is reachable
|
||||
// only through the network, so it is dropped.
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"routing peer (router.Peer) serving the resource must be affected by a policy granting access to it")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// Router defined via PeerGroups instead of a direct peer.
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"routing peer (router.PeerGroups member) serving the resource must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DestResource_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// When the resource is referenced via DestinationResource, RuleGroups()
|
||||
// returns only the source group and the resource ID is not a peer, so
|
||||
// collectPolicyAffectedGroupsAndPeers yields nothing for the destination at
|
||||
// all. The routing peer is dropped here too.
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"routing peer must be affected when the resource is referenced via DestinationResource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DestResource_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"routing peer (PeerGroups) must be affected when the resource is referenced via DestinationResource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_SourceResourcePeer_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// Source expressed as a direct peer (SourceResource), destination as resource group.
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "sourceResource-peer-to-resource",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
SourceResource: types.Resource{ID: s.sourcePeerID, Type: types.ResourceTypePeer},
|
||||
Destinations: []string{s.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// The direct source peer IS picked up (collectPolicyAffectedGroupsAndPeers
|
||||
// handles SourceResource peers), but the routing peer is still missing.
|
||||
assert.Contains(t, affected, s.sourcePeerID, "direct source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID, "routing peer must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_UnrelatedPeerNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// Guard against an over-broad fix: a peer in no relevant entity must never
|
||||
// be pulled in.
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_ResourceSideBridgesToRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A pre-existing policy grants the source group access to the resource.
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Drive an update through the resource manager and assert the routing peer
|
||||
// is among the affected set by observing the channel. This path walks
|
||||
// policies whose destinations reference the resource's groups, folds in the
|
||||
// source groups, and loads the network's routers, so it reaches both the
|
||||
// source peer and the routing peer.
|
||||
permissionsManager := permissions.NewManager(s.manager.Store)
|
||||
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
rm := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = rm.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh source peer + routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
// settleAffectedUpdates waits for in-flight async updates to arrive, then drains
|
||||
// every given channel so subsequent assertions start from a clean slate.
|
||||
//
|
||||
// Setup (CreateNetwork/CreateResource/CreateRouter) fires async UpdateAffectedPeers
|
||||
// goroutines; draining first means the assertion only observes updates from the
|
||||
// action under test, not setup stragglers.
|
||||
func settleAffectedUpdates(chans ...<-chan *network_map.UpdateMessage) {
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
for _, ch := range chans {
|
||||
drainPeerUpdates(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: creating peer->resource policy did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: routing peer (PeerGroups) not refreshed on policy create")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DestResource_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: routing peer not refreshed when policy targets DestinationResource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DeletePolicy_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
require.NoError(t, s.manager.DeletePolicy(ctx, s.accountID, policy.ID, userID))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: deleting peer->resource policy did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *routerScenario) managers() (resources.Manager, routers.Manager, networks.Manager) {
|
||||
permissionsManager := permissions.NewManager(s.manager.Store)
|
||||
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
resourcesManager := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
|
||||
routersManager := routers.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
networksManager := networks.NewManager(s.manager.Store, permissionsManager, resourcesManager, routersManager, s.manager)
|
||||
return resourcesManager, routersManager, networksManager
|
||||
}
|
||||
|
||||
type secondTopology struct {
|
||||
networkID string
|
||||
resourceID string
|
||||
resourceGroupID string
|
||||
routerPeerID string
|
||||
}
|
||||
|
||||
func (s *routerScenario) addSecondTopology(t *testing.T, suffix string) secondTopology {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
resourcesManager, routersManager, networksManager := s.managers()
|
||||
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-"+suffix, types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
routerPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
|
||||
resourceGroupID := "rs-resource-grp-" + suffix
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: resourceGroupID, Name: "rs-resource-" + suffix,
|
||||
}))
|
||||
|
||||
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
|
||||
ID: "rs-network-" + suffix,
|
||||
AccountID: s.accountID,
|
||||
Name: "rs-network-" + suffix,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
AccountID: s.accountID,
|
||||
NetworkID: network.ID,
|
||||
Name: "rs-resource-host-" + suffix,
|
||||
Address: "10.40.50.0/24",
|
||||
GroupIDs: []string{resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: network.ID,
|
||||
AccountID: s.accountID,
|
||||
Peer: routerPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return secondTopology{
|
||||
networkID: network.ID,
|
||||
resourceID: resource.ID,
|
||||
resourceGroupID: resourceGroupID,
|
||||
routerPeerID: routerPeer.ID,
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdatePolicy_BothRoutingPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "b")
|
||||
ctx := context.Background()
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerACh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
routerBCh := s.updateManager.CreateChannel(ctx, second.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, second.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerACh, routerBCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerACh)
|
||||
peerShouldReceiveUpdate(t, routerBCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.Rules[0].Destinations = []string{second.resourceGroupID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: re-pointing the policy destination did not refresh both routing peers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdatePolicy_AddSource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const secondSourceGroupID = "rs-source-grp-2"
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
|
||||
}))
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
newSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, secondSourcePeer.ID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(newSrcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, newSrcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.Rules[0].Sources = []string{s.sourceGroupID, secondSourceGroupID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: adding a source group did not refresh the new source peer + routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DestResource_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: DestinationResource policy with PeerGroups router did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, routersManager, _ := s.managers()
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-r2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: secondRouterPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9998,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "first routing peer must be affected")
|
||||
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routers, 1)
|
||||
routers[0].Enabled = false
|
||||
require.NoError(t, s.manager.Store.UpdateNetworkRouter(ctx, routers[0]))
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
res, err := s.manager.Store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, s.accountID, s.resourceID)
|
||||
require.NoError(t, err)
|
||||
res.Enabled = false
|
||||
require.NoError(t, s.manager.Store.SaveNetworkResource(ctx, res))
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRule(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.Rules[0].Enabled = false
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled rule must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_MultiRule(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "c")
|
||||
ctx := context.Background()
|
||||
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "multi-rule-two-resources",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{s.sourceGroupID},
|
||||
Destinations: []string{s.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{s.sourceGroupID},
|
||||
Destinations: []string{second.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "routing peer for resource A must be affected")
|
||||
assert.Contains(t, affected, second.routerPeerID, "routing peer for resource B must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RouterOtherNetwork(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "d")
|
||||
ctx := context.Background()
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a policy that does not target its resource")
|
||||
}
|
||||
1802
management/server/affected_peers_test.go
Normal file
1802
management/server/affected_peers_test.go
Normal file
File diff suppressed because it is too large
Load Diff
825
management/server/affectedpeers/resolver.go
Normal file
825
management/server/affectedpeers/resolver.go
Normal file
@@ -0,0 +1,825 @@
|
||||
// Package affectedpeers computes which peers' network maps a change touches, so
|
||||
// only those peers are refreshed instead of the whole account.
|
||||
//
|
||||
// Two phases keep the dependency walk off the write transaction:
|
||||
// - Load: reads the needed collections. Call INSIDE the mutating tx (consistent,
|
||||
// and before a delete/removal severs the old state).
|
||||
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
|
||||
//
|
||||
// Enabled is never consulted: toggling it is itself an observable change.
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// Snapshot is an in-memory view of the collections needed to expand a Change.
|
||||
// Loaded in-tx, walked by Expand after commit. Only the collections the Change
|
||||
// can touch are loaded; the rest stay nil (see Load).
|
||||
type Snapshot struct {
|
||||
policies []*types.Policy
|
||||
routes []*route.Route
|
||||
nsGroups []*nbdns.NameServerGroup
|
||||
dnsSettings *types.DNSSettings
|
||||
routers []*routerTypes.NetworkRouter
|
||||
resources []*resourceTypes.NetworkResource
|
||||
services []*rpservice.Service
|
||||
proxyByCluster map[string][]string
|
||||
groups map[string]*types.Group
|
||||
groupPeers map[string]map[string]struct{} // groupID -> member peer IDs
|
||||
}
|
||||
|
||||
// Load reads the collections a Change requires, inside the caller's tx. It mirrors
|
||||
// Expand's walker preconditions, loading only what the change can touch.
|
||||
func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snapshot, error) {
|
||||
snap := &Snapshot{}
|
||||
if c.isEmpty() {
|
||||
return snap, nil
|
||||
}
|
||||
|
||||
if err := snap.loadCollections(ctx, s, accountID, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := snap.loadGroupIndex(ctx, s, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return snap, nil
|
||||
}
|
||||
|
||||
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
|
||||
// collections a Change can touch, gated to what the walk needs.
|
||||
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
|
||||
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
|
||||
// the resource<->router bridge can fire for any of these
|
||||
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
|
||||
|
||||
if needsRoutersResources {
|
||||
if err := snap.loadPolicyRoutersResources(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if hasGroupOrPeerChange {
|
||||
if err := snap.loadRoutesAndProxy(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
|
||||
if err := snap.loadDNS(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadPolicyRoutersResources loads the policies plus the routers and resources
|
||||
// the resource<->router bridge walks.
|
||||
func (snap *Snapshot) loadPolicyRoutersResources(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.policies, err = s.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if snap.routers, err = s.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
snap.resources, err = s.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadRoutesAndProxy loads the routes and the embedded-proxy services index.
|
||||
func (snap *Snapshot) loadRoutesAndProxy(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.routes, err = s.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
return snap.loadProxyServices(ctx, s, accountID)
|
||||
}
|
||||
|
||||
// loadDNS loads the nameserver groups and account DNS settings.
|
||||
func (snap *Snapshot) loadDNS(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.nsGroups, err = s.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
snap.dnsSettings, err = s.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadProxyServices loads the embedded-proxy cluster index, and the services only
|
||||
// when the account actually has embedded proxy peers.
|
||||
func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.proxyByCluster, err = s.GetEmbeddedProxyPeerIDsByCluster(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(snap.proxyByCluster) == 0 {
|
||||
return nil
|
||||
}
|
||||
snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadGroupIndex loads all groups (for group.Resources) and builds the
|
||||
// group->member-peers index. Always needed: the bridge resolves group.Resources
|
||||
// and Expand maps groups to member peers.
|
||||
func (snap *Snapshot) loadGroupIndex(ctx context.Context, s store.Store, accountID string) error {
|
||||
groups, err := s.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
snap.groups = make(map[string]*types.Group, len(groups))
|
||||
snap.groupPeers = make(map[string]map[string]struct{}, len(groups))
|
||||
for _, g := range groups {
|
||||
snap.groups[g.ID] = g
|
||||
members := make(map[string]struct{}, len(g.Peers))
|
||||
for _, pID := range g.Peers {
|
||||
members[pID] = struct{}{}
|
||||
}
|
||||
snap.groupPeers[g.ID] = members
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change describes what changed in an account.
|
||||
type Change struct {
|
||||
ChangedGroupIDs []string
|
||||
ChangedPeerIDs []string
|
||||
Policies []*types.Policy
|
||||
Routes []*route.Route
|
||||
Routers []*routerTypes.NetworkRouter
|
||||
Resources []*resourceTypes.NetworkResource
|
||||
Networks []*networkTypes.Network
|
||||
PostureCheckIDs []string
|
||||
|
||||
// DistributionGroupIDs are groups whose members are directly affected, with no
|
||||
// dependency walk — the change distributes config to the groups' member peers
|
||||
// only (nameserver groups, DNS DisabledManagementGroups), not through the
|
||||
// policy/route reachability graph. Pass old∪new so both states refresh.
|
||||
DistributionGroupIDs []string
|
||||
|
||||
// RemovedPeersByGroup: peers that left a group, keyed by that group. They are no
|
||||
// longer in the group's member index but still lose its reachability, so they are
|
||||
// folded in — but only when the group is linked (an unlinked group has no map
|
||||
// impact), matching how current members are handled.
|
||||
RemovedPeersByGroup map[string][]string
|
||||
}
|
||||
|
||||
func (c Change) isEmpty() bool {
|
||||
return len(c.ChangedGroupIDs) == 0 &&
|
||||
len(c.ChangedPeerIDs) == 0 &&
|
||||
len(c.Policies) == 0 &&
|
||||
len(c.Routes) == 0 &&
|
||||
len(c.Routers) == 0 &&
|
||||
len(c.Resources) == 0 &&
|
||||
len(c.Networks) == 0 &&
|
||||
len(c.PostureCheckIDs) == 0 &&
|
||||
len(c.DistributionGroupIDs) == 0 &&
|
||||
len(c.RemovedPeersByGroup) == 0
|
||||
}
|
||||
|
||||
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
|
||||
// no store access. Run after the producing tx commits. Logs the full walk at
|
||||
// trace level for diagnosing a miscalculation.
|
||||
func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []string {
|
||||
if c.isEmpty() {
|
||||
return nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
r.walk()
|
||||
return r.expand()
|
||||
}
|
||||
|
||||
// Collect returns the affected group and direct-peer IDs without expanding groups
|
||||
// to members. Test-only introspection; use Resolve otherwise.
|
||||
func Collect(ctx context.Context, s store.Store, accountID string, c Change) (groupIDs []string, directPeerIDs []string) {
|
||||
if c.isEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
snap, err := Load(ctx, s, accountID, c)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load snapshot for affected peers collect: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
r.walk()
|
||||
return setToSlice(r.groupSet), setToSlice(r.peerSet)
|
||||
}
|
||||
|
||||
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
}
|
||||
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
|
||||
r.seedChangedGroupsFromPeers()
|
||||
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
|
||||
return r
|
||||
}
|
||||
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
|
||||
// the group-driven walkers fire for memberships, not just direct peer references.
|
||||
func (r *resolver) seedChangedGroupsFromPeers() {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
return
|
||||
}
|
||||
for groupID, members := range r.snap.groupPeers {
|
||||
for pID := range r.changedPeerSet {
|
||||
if _, ok := members[pID]; ok {
|
||||
r.changedGroupSet[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) walk() {
|
||||
r.collectFromExplicitPolicies()
|
||||
r.collectFromExplicitRoutes(r.change.Routes)
|
||||
r.collectFromExplicitRouters(r.change.Routers)
|
||||
r.collectFromExplicitResources(r.change.Resources)
|
||||
r.collectFromExplicitNetworks(r.change.Networks)
|
||||
r.collectFromPostureChecks(r.change.PostureCheckIDs)
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into groupSet so expand() maps them to members, without the policy/
|
||||
// route walk that changedGroupSet would trigger.
|
||||
addAll(r.groupSet, r.change.DistributionGroupIDs)
|
||||
|
||||
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
|
||||
r.collectFromPolicies()
|
||||
r.collectFromRoutes()
|
||||
r.collectFromNameServers()
|
||||
r.collectFromDNSSettings()
|
||||
r.collectFromNetworkRouters()
|
||||
r.collectFromProxyServices()
|
||||
}
|
||||
|
||||
r.collectResourceRouterBridge()
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
ctx context.Context
|
||||
snap *Snapshot
|
||||
accountID string
|
||||
change Change
|
||||
|
||||
changedGroupSet map[string]struct{}
|
||||
changedPeerSet map[string]struct{}
|
||||
|
||||
groupSet map[string]struct{}
|
||||
peerSet map[string]struct{}
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
networkIDs map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
|
||||
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
|
||||
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
|
||||
|
||||
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
|
||||
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var ids []string
|
||||
for gID := range groupSet {
|
||||
for pID := range r.snap.groupPeers[gID] {
|
||||
if _, ok := seen[pID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[pID] = struct{}{}
|
||||
ids = append(ids, pID)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (r *resolver) expand() []string {
|
||||
peerIDs := r.peerIDsForGroups(r.groupSet)
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
|
||||
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range r.peerSet {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Fold in removed peers only when their group is linked (in groupSet).
|
||||
for groupID, removed := range r.change.RemovedPeersByGroup {
|
||||
if _, linked := r.groupSet[groupID]; !linked {
|
||||
continue
|
||||
}
|
||||
for _, id := range removed {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: removed peer %s from linked group %s -> affected", id, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand done: account=%s -> %d affected peers: %v", r.accountID, len(peerIDs), peerIDs)
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitPolicies() {
|
||||
for _, policy := range r.matchedPolicies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitRouters folds changed routers' peers and marks their networks
|
||||
// for the bridge. Passing the old router keeps a repointed router's previous peers
|
||||
// affected without a post-commit read.
|
||||
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
|
||||
for _, router := range routers {
|
||||
if router == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitResources marks changed resources' networks for the bridge and
|
||||
// treats their group IDs as changed, so policies targeting the resource via a
|
||||
// now-detached (old) group still refresh.
|
||||
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
|
||||
for _, resource := range resources {
|
||||
if resource == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
|
||||
resource.ID, resource.NetworkID, resource.GroupIDs)
|
||||
addAll(r.changedGroupSet, resource.GroupIDs)
|
||||
if resource.NetworkID != "" {
|
||||
r.networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
|
||||
// no groups/peers of its own.
|
||||
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
|
||||
if network.ID != "" {
|
||||
r.networkIDs[network.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
for _, rt := range r.snap.routes {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNameServers() {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
for _, ns := range r.snap.nsGroups {
|
||||
if anyInSet(ns.Groups, r.changedGroupSet) {
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
|
||||
addAll(r.groupSet, ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromDNSSettings() {
|
||||
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
|
||||
return
|
||||
}
|
||||
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := r.changedGroupSet[gID]; ok {
|
||||
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
|
||||
r.groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNetworkRouters() {
|
||||
for _, router := range r.networkRouters() {
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromProxyServices() {
|
||||
if len(r.snap.proxyByCluster) == 0 || len(r.snap.services) == 0 {
|
||||
return
|
||||
}
|
||||
services, proxyByCluster := r.snap.services, r.snap.proxyByCluster
|
||||
|
||||
expanded := r.expandChangedPeersWithGroups()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == nil {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
|
||||
if !matchedByPeer && !matchedByAccessGroup {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
|
||||
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
|
||||
for _, pid := range proxyPeers {
|
||||
r.peerSet[pid] = struct{}{}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
|
||||
r.peerSet[target.TargetId] = struct{}{}
|
||||
}
|
||||
}
|
||||
addAll(r.groupSet, svc.AccessGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
ids := r.peerIDsForGroups(r.changedGroupSet)
|
||||
if len(ids) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
|
||||
for id := range r.changedPeerSet {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
for _, id := range ids {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// collectResourceRouterBridge crosses between source peers and routing peers, which
|
||||
// are reachable only via resource -> network -> router, not through the policy's own
|
||||
// groups: source -> router (targeted resources' networks), then router -> source.
|
||||
func (r *resolver) collectResourceRouterBridge() {
|
||||
r.bridgeSourceToRouters()
|
||||
r.bridgeRoutersToSources()
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeSourceToRouters() {
|
||||
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
networkIDs := r.resourceNetworkIDs(resourceIDs)
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeRoutersToSources() {
|
||||
if len(r.networkIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
|
||||
setToSlice(r.networkIDs))
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.groupSet, r.peerSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
for _, router := range r.networkRouters() {
|
||||
if _, ok := networkIDs[router.NetworkID]; !ok {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) resourceNetworkIDs(resourceIDs map[string]struct{}) map[string]struct{} {
|
||||
networkIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := resourceIDs[resource.ID]; ok {
|
||||
networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return networkIDs
|
||||
}
|
||||
|
||||
func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[string]struct{}) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
for _, gID := range rule.Destinations {
|
||||
destGroupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(destGroupSet) == 0 {
|
||||
return false
|
||||
}
|
||||
for gID := range destGroupSet {
|
||||
group := r.snap.groups[gID]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
for _, res := range group.Resources {
|
||||
if isInSet(res.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *resolver) policyDestinationResourceIDs(policies ...*types.Policy) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
destGroupSet := collectPolicyDestinations(resourceIDs, policies...)
|
||||
r.addGroupResourceIDs(destGroupSet, resourceIDs)
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
// collectPolicyDestinations adds direct destination resource IDs to resourceIDs and
|
||||
// returns the referenced destination group IDs.
|
||||
func collectPolicyDestinations(resourceIDs map[string]struct{}, policies ...*types.Policy) map[string]struct{} {
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, policy := range policies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(destGroupSet, rule.Destinations)
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
resourceIDs[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return destGroupSet
|
||||
}
|
||||
|
||||
// addGroupResourceIDs folds the resource IDs of the given groups into resourceIDs.
|
||||
func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs map[string]struct{}) {
|
||||
for gID := range groupIDs {
|
||||
group := r.snap.groups[gID]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
for _, res := range group.Resources {
|
||||
if res.ID != "" {
|
||||
resourceIDs[res.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
|
||||
if res.Type != types.ResourceTypePeer || res.ID == "" {
|
||||
return false
|
||||
}
|
||||
_, ok := set[res.ID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, changedPeers map[string]struct{}) bool {
|
||||
for _, pid := range proxyPeers {
|
||||
if _, ok := changedPeers[pid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedPeers[target.TargetId]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isInSet(id string, set map[string]struct{}) bool {
|
||||
_, ok := set[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func addAll(set map[string]struct{}, slices ...[]string) {
|
||||
for _, s := range slices {
|
||||
for _, id := range s {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toSet(ids []string) map[string]struct{} {
|
||||
set := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func setToSlice(set map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(set))
|
||||
for id := range set {
|
||||
s = append(s, id)
|
||||
}
|
||||
return s
|
||||
}
|
||||
140
management/server/affectedpeers/resolver_test.go
Normal file
140
management/server/affectedpeers/resolver_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
|
||||
// direct peers) the resolver folds in, for asserting the pure logic.
|
||||
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
|
||||
peerSet := map[string]struct{}{}
|
||||
for _, p := range policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, p.RuleGroups()...)
|
||||
collectPolicyDirectPeers(p, peerSet)
|
||||
}
|
||||
for id := range peerSet {
|
||||
peers = append(peers, id)
|
||||
}
|
||||
return groups, peers
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_Basic(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2", "g3"}, groups)
|
||||
assert.Empty(t, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_WithPeerResources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{ID: "p1", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{"g2"},
|
||||
DestinationResource: types.Resource{ID: "p2", Type: types.ResourceTypePeer},
|
||||
}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
|
||||
assert.ElementsMatch(t, []string{"p1", "p2"}, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_NilPolicy(t *testing.T) {
|
||||
groups, peers := policyGroupsAndPeers(nil)
|
||||
assert.Nil(t, groups)
|
||||
assert.Nil(t, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_MultiplePolicies(t *testing.T) {
|
||||
old := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1"}, Destinations: []string{"g2"}}}}
|
||||
updated := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g3"}, Destinations: []string{"g4"}}}}
|
||||
groups, _ := policyGroupsAndPeers(updated, old)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2", "g3", "g4"}, groups)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_NonPeerResource(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{ID: "domain-1", Type: types.ResourceTypeDomain},
|
||||
Destinations: []string{"g2"},
|
||||
}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
|
||||
assert.Empty(t, peers, "domain resource type should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.True(t, Change{}.isEmpty())
|
||||
assert.False(t, Change{ChangedGroupIDs: []string{"g"}}.isEmpty())
|
||||
assert.False(t, Change{ChangedPeerIDs: []string{"p"}}.isEmpty())
|
||||
assert.False(t, Change{Policies: []*types.Policy{{}}}.isEmpty())
|
||||
assert.False(t, Change{Resources: []*resourceTypes.NetworkResource{{ID: "r"}}}.isEmpty())
|
||||
assert.False(t, Change{Networks: []*networkTypes.Network{{ID: "n"}}}.isEmpty())
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
assert.True(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc1": {}}))
|
||||
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
|
||||
}
|
||||
|
||||
func TestCollectPolicyDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
|
||||
}, {
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
assert.Contains(t, peerSet, "p2")
|
||||
assert.NotContains(t, peerSet, "r1")
|
||||
}
|
||||
|
||||
func TestCollectPolicySources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
Destinations: []string{"g2"},
|
||||
}}}
|
||||
|
||||
groupSet := map[string]struct{}{}
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicySources(policy, groupSet, peerSet)
|
||||
|
||||
assert.Contains(t, groupSet, "g1")
|
||||
assert.NotContains(t, groupSet, "g2", "destination groups must not be collected as sources")
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
@@ -74,7 +75,10 @@ func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth
|
||||
}
|
||||
|
||||
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
// Child accounts and PAT-authenticated requests do not use JWT group access checks.
|
||||
// Embedded-Dex local users also skip them because local password authentication
|
||||
// does not provide external IdP group claims.
|
||||
if userAuth.IsChild || userAuth.IsPAT || dex.IsLocalUserID(userAuth.UserId) {
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -206,6 +207,43 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
||||
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.Error(t, err, "ensure user access is not in allowed groups")
|
||||
})
|
||||
|
||||
t.Run("Local embedded-Dex user is exempt from JWT allow-groups", func(t *testing.T) {
|
||||
account.Settings.JWTGroupsEnabled = true
|
||||
account.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
account.Settings.JWTAllowGroups = []string{"not-a-group"}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
// Local Dex users have a "local" connector encoded in their user ID.
|
||||
localUserAuth := nbauth.UserAuth{
|
||||
AccountId: account.Id,
|
||||
Domain: domain,
|
||||
UserId: dex.EncodeDexUserID("local-owner", "local"),
|
||||
}
|
||||
|
||||
localUserAuth, err = manager.EnsureUserAccessByJWTGroups(context.Background(), localUserAuth, token)
|
||||
require.NoError(t, err, "local user must not be locked out by JWT allow-groups (issue #5337)")
|
||||
require.Len(t, localUserAuth.Groups, 0, "JWT groups must not be evaluated for local users")
|
||||
})
|
||||
|
||||
t.Run("Federated embedded-Dex user is still subject to JWT allow-groups", func(t *testing.T) {
|
||||
account.Settings.JWTGroupsEnabled = true
|
||||
account.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
account.Settings.JWTAllowGroups = []string{"not-a-group"}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
// A federated user (non-"local" connector) must remain restricted.
|
||||
fedUserAuth := nbauth.UserAuth{
|
||||
AccountId: account.Id,
|
||||
Domain: domain,
|
||||
UserId: dex.EncodeDexUserID("entra-user", "entra"),
|
||||
}
|
||||
|
||||
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), fedUserAuth, token)
|
||||
require.Error(t, err, "federated user must still be restricted by JWT allow-groups")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -47,8 +48,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
@@ -63,11 +65,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
@@ -75,6 +72,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
change = affectedpeers.Change{DistributionGroupIDs: slices.Concat(addedGroups, removedGroups)}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -85,9 +87,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -133,20 +133,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
|
||||
@@ -298,11 +298,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
|
||||
savedPeer1, _, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
|
||||
_, _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -79,7 +80,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -91,11 +93,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -106,6 +103,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -116,9 +118,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -134,7 +134,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -153,20 +154,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
|
||||
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
if err = syncGroupMembership(ctx, transaction, accountID, newGroup.ID, peersToAdd, peersToRemove); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -178,6 +166,17 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
// A membership change does not alter which entities reference the group, so
|
||||
// the dependency walk runs once against the post-change snapshot. The new
|
||||
// members are already in the snapshot's index; the removed members are
|
||||
// carried separately and folded in only when the group is linked.
|
||||
if len(peersToRemove) > 0 {
|
||||
change.RemovedPeersByGroup = map[string][]string{newGroup.ID: peersToRemove}
|
||||
}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -188,13 +187,26 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncGroupMembership applies the peer membership delta for a group within a transaction.
|
||||
func syncGroupMembership(ctx context.Context, transaction store.Store, accountID, groupID string, peersToAdd, peersToRemove []string) error {
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, groupID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
@@ -209,11 +221,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var snaps []*affectedpeers.Snapshot
|
||||
var changes []affectedpeers.Change
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
createdCount := 0
|
||||
for _, newGroup := range groups {
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
var snap *affectedpeers.Snapshot
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
@@ -230,35 +245,31 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
return err
|
||||
}
|
||||
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return nil
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if len(groupIDs) == 1 {
|
||||
if createdCount == 0 {
|
||||
return err
|
||||
}
|
||||
globalErr = errors.Join(globalErr, err)
|
||||
// continue updating other groups
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
createdCount++
|
||||
snaps = append(snaps, snap)
|
||||
changes = append(changes, change)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
go am.dispatchAffected(ctx, accountID, snaps, changes)
|
||||
|
||||
return globalErr
|
||||
}
|
||||
@@ -277,12 +288,13 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var snaps []*affectedpeers.Snapshot
|
||||
var changes []affectedpeers.Change
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup)
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
events, snap, err := am.updateSingleGroup(ctx, accountID, userID, newGroup, change)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if len(groups) == 1 {
|
||||
@@ -292,27 +304,22 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
continue
|
||||
}
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
snaps = append(snaps, snap)
|
||||
changes = append(changes, change)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
go am.dispatchAffected(ctx, accountID, snaps, changes)
|
||||
|
||||
return globalErr
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) {
|
||||
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, change affectedpeers.Change) ([]func(), *affectedpeers.Snapshot, error) {
|
||||
var events []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
@@ -333,9 +340,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
return nil
|
||||
|
||||
var err error
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
return err
|
||||
})
|
||||
return events, err
|
||||
return events, snap, err
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
@@ -438,6 +448,8 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*types.Group
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
|
||||
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -445,26 +457,23 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
deletedGroups, allErrors = collectDeletableGroups(ctx, transaction, accountID, userID, groupIDs, extraSettings.FlowGroups)
|
||||
for _, group := range deletedGroups {
|
||||
groupIDsToDelete = append(groupIDsToDelete, group.ID)
|
||||
}
|
||||
|
||||
if len(groupIDsToDelete) == 0 {
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// Delete: compute affected peers from the PRE-delete state. The groups,
|
||||
// their members and the entities referencing them still exist, so a plain
|
||||
// Load+Expand captures everyone — no removed-peer folding needed.
|
||||
change = affectedpeers.Change{ChangedGroupIDs: groupIDsToDelete}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -483,25 +492,47 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// collectDeletableGroups loads and validates each group for deletion, returning
|
||||
// the groups that may be deleted and the joined validation errors for the rest.
|
||||
func collectDeletableGroups(ctx context.Context, transaction store.Store, accountID, userID string, groupIDs, flowGroups []string) ([]*types.Group, error) {
|
||||
var deletable []*types.Group
|
||||
var allErrors error
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, flowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
deletable = append(deletable, group)
|
||||
}
|
||||
return deletable, allErrors
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -511,9 +542,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -521,8 +550,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
// GroupAddResource appends resource to the group
|
||||
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
var err error
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
@@ -534,12 +564,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -549,29 +578,32 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{
|
||||
ChangedGroupIDs: []string{groupID},
|
||||
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
// The removed peer is carried in change.RemovedPeersByGroup and folded in
|
||||
// only when the group is linked, so loading post-removal is correct.
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -581,9 +613,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -591,8 +621,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
// GroupDeleteResource removes resource from the group
|
||||
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var snap *affectedpeers.Snapshot
|
||||
var err error
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
@@ -604,8 +635,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
// Load before persisting the removal, so the snapshot still maps the group
|
||||
// to the resource and the bridge can reach its routing peers.
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -619,9 +651,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -832,49 +862,103 @@ func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store,
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
// It fetches each collection once and checks all groupIDs against them in memory.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
groupSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
if affected, err := dnsSettingsReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := nameServersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := policiesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := routesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := networkRoutersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return anyInSet(dnsSettings.DisabledManagementGroups, groupSet), nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||
func nameServersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() || group.HasResources() {
|
||||
for _, ns := range nameServerGroups {
|
||||
if anyInSet(ns.Groups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func policiesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func routesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, r := range routes {
|
||||
if anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func networkRoutersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, router := range routers {
|
||||
if anyInSet(router.PeerGroups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestGroupIPv6Assignment(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"},
|
||||
}, false)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user